# Mount Google Drive so the dataset archive in MyDrive is readable from this Colab VM.
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
import pandas as pd
from sklearn.mixture import GaussianMixture
from scipy.stats import multivariate_normal
!unzip "/content/drive/MyDrive/Q8_Dataset.zip"
Archive: /content/drive/MyDrive/Q8_Dataset.zip creating: Q8_Dataset/ inflating: __MACOSX/._Q8_Dataset inflating: Q8_Dataset/.DS_Store inflating: __MACOSX/Q8_Dataset/._.DS_Store creating: Q8_Dataset/Images/ inflating: __MACOSX/Q8_Dataset/._Images inflating: Q8_Dataset/Images/c8.jpg inflating: __MACOSX/Q8_Dataset/Images/._c8.jpg inflating: Q8_Dataset/Images/m3.jpg inflating: __MACOSX/Q8_Dataset/Images/._m3.jpg inflating: Q8_Dataset/Images/c27.jpg inflating: __MACOSX/Q8_Dataset/Images/._c27.jpg inflating: Q8_Dataset/Images/c33.jpg inflating: __MACOSX/Q8_Dataset/Images/._c33.jpg inflating: Q8_Dataset/Images/m28.jpg inflating: __MACOSX/Q8_Dataset/Images/._m28.jpg inflating: Q8_Dataset/Images/m14.jpg inflating: __MACOSX/Q8_Dataset/Images/._m14.jpg inflating: Q8_Dataset/Images/m15.jpg inflating: __MACOSX/Q8_Dataset/Images/._m15.jpg inflating: Q8_Dataset/Images/m29.jpg inflating: __MACOSX/Q8_Dataset/Images/._m29.jpg inflating: Q8_Dataset/Images/c32.jpg inflating: __MACOSX/Q8_Dataset/Images/._c32.jpg inflating: Q8_Dataset/Images/c26.jpg inflating: __MACOSX/Q8_Dataset/Images/._c26.jpg inflating: Q8_Dataset/Images/m2.jpg inflating: __MACOSX/Q8_Dataset/Images/._m2.jpg inflating: Q8_Dataset/Images/c9.jpg inflating: __MACOSX/Q8_Dataset/Images/._c9.jpg inflating: Q8_Dataset/Images/c30.jpg inflating: __MACOSX/Q8_Dataset/Images/._c30.jpg inflating: Q8_Dataset/Images/c24.jpg inflating: __MACOSX/Q8_Dataset/Images/._c24.jpg inflating: Q8_Dataset/Images/c18.jpg inflating: __MACOSX/Q8_Dataset/Images/._c18.jpg inflating: Q8_Dataset/Images/m17.jpg inflating: __MACOSX/Q8_Dataset/Images/._m17.jpg inflating: Q8_Dataset/Images/m16.jpg inflating: __MACOSX/Q8_Dataset/Images/._m16.jpg inflating: Q8_Dataset/Images/c19.jpg inflating: __MACOSX/Q8_Dataset/Images/._c19.jpg inflating: Q8_Dataset/Images/c25.jpg inflating: __MACOSX/Q8_Dataset/Images/._c25.jpg inflating: Q8_Dataset/Images/c31.jpg inflating: __MACOSX/Q8_Dataset/Images/._c31.jpg inflating: Q8_Dataset/Images/m1.jpg inflating: 
__MACOSX/Q8_Dataset/Images/._m1.jpg inflating: Q8_Dataset/Images/c35.jpg inflating: __MACOSX/Q8_Dataset/Images/._c35.jpg inflating: Q8_Dataset/Images/c21.jpg inflating: __MACOSX/Q8_Dataset/Images/._c21.jpg inflating: Q8_Dataset/Images/m12.jpg inflating: __MACOSX/Q8_Dataset/Images/._m12.jpg inflating: Q8_Dataset/Images/ManUtd-508878051576073641743_medium.jpg inflating: __MACOSX/Q8_Dataset/Images/._ManUtd-508878051576073641743_medium.jpg inflating: Q8_Dataset/Images/index.jpg inflating: __MACOSX/Q8_Dataset/Images/._index.jpg inflating: Q8_Dataset/Images/m13.jpg inflating: __MACOSX/Q8_Dataset/Images/._m13.jpg inflating: Q8_Dataset/Images/c20.jpg inflating: __MACOSX/Q8_Dataset/Images/._c20.jpg inflating: Q8_Dataset/Images/c34.jpg inflating: __MACOSX/Q8_Dataset/Images/._c34.jpg inflating: Q8_Dataset/Images/m4.jpg inflating: __MACOSX/Q8_Dataset/Images/._m4.jpg inflating: Q8_Dataset/Images/c22.jpg inflating: __MACOSX/Q8_Dataset/Images/._c22.jpg inflating: Q8_Dataset/Images/c36.jpg inflating: __MACOSX/Q8_Dataset/Images/._c36.jpg inflating: Q8_Dataset/Images/m11.jpg inflating: __MACOSX/Q8_Dataset/Images/._m11.jpg inflating: Q8_Dataset/Images/m39.jpg inflating: __MACOSX/Q8_Dataset/Images/._m39.jpg inflating: Q8_Dataset/Images/m38.jpg inflating: __MACOSX/Q8_Dataset/Images/._m38.jpg inflating: Q8_Dataset/Images/c37.jpg inflating: __MACOSX/Q8_Dataset/Images/._c37.jpg inflating: Q8_Dataset/Images/c23.jpg inflating: __MACOSX/Q8_Dataset/Images/._c23.jpg inflating: Q8_Dataset/Images/c44.jpg inflating: __MACOSX/Q8_Dataset/Images/._c44.jpg inflating: Q8_Dataset/Images/c50.jpg inflating: __MACOSX/Q8_Dataset/Images/._c50.jpg inflating: Q8_Dataset/Images/m63.jpg inflating: __MACOSX/Q8_Dataset/Images/._m63.jpg inflating: Q8_Dataset/Images/m62.jpg inflating: __MACOSX/Q8_Dataset/Images/._m62.jpg inflating: Q8_Dataset/Images/c51.jpg inflating: __MACOSX/Q8_Dataset/Images/._c51.jpg inflating: Q8_Dataset/Images/c45.jpg inflating: __MACOSX/Q8_Dataset/Images/._c45.jpg inflating: 
Q8_Dataset/Images/c53.jpg inflating: __MACOSX/Q8_Dataset/Images/._c53.jpg inflating: Q8_Dataset/Images/c47.jpg inflating: __MACOSX/Q8_Dataset/Images/._c47.jpg inflating: Q8_Dataset/Images/m48.jpg inflating: __MACOSX/Q8_Dataset/Images/._m48.jpg inflating: Q8_Dataset/Images/m60.jpg inflating: __MACOSX/Q8_Dataset/Images/._m60.jpg inflating: Q8_Dataset/Images/m61.jpg inflating: __MACOSX/Q8_Dataset/Images/._m61.jpg inflating: Q8_Dataset/Images/m49.jpg inflating: __MACOSX/Q8_Dataset/Images/._m49.jpg inflating: Q8_Dataset/Images/c46.jpg inflating: __MACOSX/Q8_Dataset/Images/._c46.jpg inflating: Q8_Dataset/Images/c52.jpg inflating: __MACOSX/Q8_Dataset/Images/._c52.jpg inflating: Q8_Dataset/Images/c56.jpg inflating: __MACOSX/Q8_Dataset/Images/._c56.jpg inflating: Q8_Dataset/Images/c42.jpg inflating: __MACOSX/Q8_Dataset/Images/._c42.jpg inflating: Q8_Dataset/Images/m59.jpg inflating: __MACOSX/Q8_Dataset/Images/._m59.jpg inflating: Q8_Dataset/Images/m58.jpg inflating: __MACOSX/Q8_Dataset/Images/._m58.jpg inflating: Q8_Dataset/Images/c43.jpg inflating: __MACOSX/Q8_Dataset/Images/._c43.jpg inflating: Q8_Dataset/Images/c57.jpg inflating: __MACOSX/Q8_Dataset/Images/._c57.jpg inflating: Q8_Dataset/Images/c41.jpg inflating: __MACOSX/Q8_Dataset/Images/._c41.jpg inflating: Q8_Dataset/Images/c55.jpg inflating: __MACOSX/Q8_Dataset/Images/._c55.jpg inflating: Q8_Dataset/Images/c54.jpg inflating: __MACOSX/Q8_Dataset/Images/._c54.jpg inflating: Q8_Dataset/Images/c40.jpg inflating: __MACOSX/Q8_Dataset/Images/._c40.jpg inflating: Q8_Dataset/Images/c65.jpg inflating: __MACOSX/Q8_Dataset/Images/._c65.jpg inflating: Q8_Dataset/Images/c59.jpg inflating: __MACOSX/Q8_Dataset/Images/._c59.jpg inflating: Q8_Dataset/Images/m42.jpg inflating: __MACOSX/Q8_Dataset/Images/._m42.jpg inflating: Q8_Dataset/Images/m56.jpg inflating: __MACOSX/Q8_Dataset/Images/._m56.jpg inflating: Q8_Dataset/Images/m57.jpg inflating: __MACOSX/Q8_Dataset/Images/._m57.jpg inflating: Q8_Dataset/Images/m43.jpg inflating: 
__MACOSX/Q8_Dataset/Images/._m43.jpg inflating: Q8_Dataset/Images/c58.jpg inflating: __MACOSX/Q8_Dataset/Images/._c58.jpg inflating: Q8_Dataset/Images/c64.jpg inflating: __MACOSX/Q8_Dataset/Images/._c64.jpg inflating: Q8_Dataset/Images/m55.jpg inflating: __MACOSX/Q8_Dataset/Images/._m55.jpg inflating: Q8_Dataset/Images/m41.jpg inflating: __MACOSX/Q8_Dataset/Images/._m41.jpg inflating: Q8_Dataset/Images/m40.jpg inflating: __MACOSX/Q8_Dataset/Images/._m40.jpg inflating: Q8_Dataset/Images/m54.jpg inflating: __MACOSX/Q8_Dataset/Images/._m54.jpg inflating: Q8_Dataset/Images/m50.jpg inflating: __MACOSX/Q8_Dataset/Images/._m50.jpg inflating: Q8_Dataset/Images/m44.jpg inflating: __MACOSX/Q8_Dataset/Images/._m44.jpg inflating: Q8_Dataset/Images/m45.jpg inflating: __MACOSX/Q8_Dataset/Images/._m45.jpg inflating: Q8_Dataset/Images/m51.jpg inflating: __MACOSX/Q8_Dataset/Images/._m51.jpg inflating: Q8_Dataset/Images/c48.jpg inflating: __MACOSX/Q8_Dataset/Images/._c48.jpg inflating: Q8_Dataset/Images/c60.jpg inflating: __MACOSX/Q8_Dataset/Images/._c60.jpg inflating: Q8_Dataset/Images/m47.jpg inflating: __MACOSX/Q8_Dataset/Images/._m47.jpg inflating: Q8_Dataset/Images/m53.jpg inflating: __MACOSX/Q8_Dataset/Images/._m53.jpg inflating: Q8_Dataset/Images/m52.jpg inflating: __MACOSX/Q8_Dataset/Images/._m52.jpg inflating: Q8_Dataset/Images/m46.jpg inflating: __MACOSX/Q8_Dataset/Images/._m46.jpg inflating: Q8_Dataset/Images/c61.jpg inflating: __MACOSX/Q8_Dataset/Images/._c61.jpg inflating: Q8_Dataset/Images/c49.jpg inflating: __MACOSX/Q8_Dataset/Images/._c49.jpg inflating: Q8_Dataset/Images/c1.jpg inflating: __MACOSX/Q8_Dataset/Images/._c1.jpg inflating: Q8_Dataset/Images/c12.jpg inflating: __MACOSX/Q8_Dataset/Images/._c12.jpg inflating: Q8_Dataset/Images/m21.jpg inflating: __MACOSX/Q8_Dataset/Images/._m21.jpg inflating: Q8_Dataset/Images/m34.jpg inflating: __MACOSX/Q8_Dataset/Images/._m34.jpg inflating: Q8_Dataset/Images/m20.jpg inflating: __MACOSX/Q8_Dataset/Images/._m20.jpg 
inflating: Q8_Dataset/Images/c13.jpg inflating: __MACOSX/Q8_Dataset/Images/._c13.jpg inflating: Q8_Dataset/Images/c2.jpg inflating: __MACOSX/Q8_Dataset/Images/._c2.jpg inflating: Q8_Dataset/Images/m9.jpg inflating: __MACOSX/Q8_Dataset/Images/._m9.jpg inflating: Q8_Dataset/Images/c11.jpg inflating: __MACOSX/Q8_Dataset/Images/._c11.jpg inflating: Q8_Dataset/Images/c39.jpg inflating: __MACOSX/Q8_Dataset/Images/._c39.jpg inflating: Q8_Dataset/Images/m22.jpg inflating: __MACOSX/Q8_Dataset/Images/._m22.jpg inflating: Q8_Dataset/Images/m23.jpg inflating: __MACOSX/Q8_Dataset/Images/._m23.jpg inflating: Q8_Dataset/Images/c38.jpg inflating: __MACOSX/Q8_Dataset/Images/._c38.jpg inflating: Q8_Dataset/Images/c10.jpg inflating: __MACOSX/Q8_Dataset/Images/._c10.jpg inflating: Q8_Dataset/Images/c3.jpg inflating: __MACOSX/Q8_Dataset/Images/._c3.jpg inflating: Q8_Dataset/Images/c7.jpg inflating: __MACOSX/Q8_Dataset/Images/._c7.jpg inflating: Q8_Dataset/Images/c28.jpg inflating: __MACOSX/Q8_Dataset/Images/._c28.jpg inflating: Q8_Dataset/Images/c14.jpg inflating: __MACOSX/Q8_Dataset/Images/._c14.jpg inflating: Q8_Dataset/Images/m33.jpg inflating: __MACOSX/Q8_Dataset/Images/._m33.jpg inflating: Q8_Dataset/Images/m27.jpg inflating: __MACOSX/Q8_Dataset/Images/._m27.jpg inflating: Q8_Dataset/Images/m26.jpg inflating: __MACOSX/Q8_Dataset/Images/._m26.jpg inflating: Q8_Dataset/Images/m32.jpg inflating: __MACOSX/Q8_Dataset/Images/._m32.jpg inflating: Q8_Dataset/Images/c15.jpg inflating: __MACOSX/Q8_Dataset/Images/._c15.jpg inflating: Q8_Dataset/Images/c29.jpg inflating: __MACOSX/Q8_Dataset/Images/._c29.jpg inflating: Q8_Dataset/Images/c6.jpg inflating: __MACOSX/Q8_Dataset/Images/._c6.jpg inflating: Q8_Dataset/Images/c4.jpg inflating: __MACOSX/Q8_Dataset/Images/._c4.jpg inflating: Q8_Dataset/Images/1.jpg inflating: __MACOSX/Q8_Dataset/Images/._1.jpg inflating: Q8_Dataset/Images/images.jpg inflating: __MACOSX/Q8_Dataset/Images/._images.jpg inflating: Q8_Dataset/Images/c17.jpg inflating: 
__MACOSX/Q8_Dataset/Images/._c17.jpg inflating: Q8_Dataset/Images/m24.jpg inflating: __MACOSX/Q8_Dataset/Images/._m24.jpg inflating: Q8_Dataset/Images/m30.jpg inflating: __MACOSX/Q8_Dataset/Images/._m30.jpg inflating: Q8_Dataset/Images/m18.jpg inflating: __MACOSX/Q8_Dataset/Images/._m18.jpg inflating: Q8_Dataset/Images/m19.jpg inflating: __MACOSX/Q8_Dataset/Images/._m19.jpg inflating: Q8_Dataset/Images/m31.jpg inflating: __MACOSX/Q8_Dataset/Images/._m31.jpg inflating: Q8_Dataset/Images/m25.jpg inflating: __MACOSX/Q8_Dataset/Images/._m25.jpg inflating: Q8_Dataset/Images/c16.jpg inflating: __MACOSX/Q8_Dataset/Images/._c16.jpg inflating: Q8_Dataset/Images/c5.jpg inflating: __MACOSX/Q8_Dataset/Images/._c5.jpg
base_folder = '/content/Q8_Dataset/Images'
images = []
# Sort so the row order of the feature table is deterministic
# (os.listdir returns entries in arbitrary order).
folder = sorted(os.listdir(base_folder))
for im in folder:
    image = cv2.imread(os.path.join(base_folder, im))
    # cv2.imread returns None for files it cannot decode (hidden/metadata
    # files, non-images); passing None to cv2.resize would raise.
    if image is None:
        print(f'Skipping unreadable file: {im}')
        continue
    images.append(cv2.resize(image, (112, 112)))
images = np.array(images)
features = []
Team = []
for image in images:
    # cv2.mean returns per-channel means in OpenCV's channel order (B, G, R, A).
    channels = cv2.mean(image)
    red_mean, blue_mean = channels[2], channels[0]
    # The feature vector is the same for both teams, so build it once.
    features.append([red_mean, blue_mean])
    # Label by dominant colour: bluer than red -> 1 (chelsea), otherwise 0 (manchester).
    Team.append(1 if blue_mean > red_mean else 0)
# Feature table: one row per image, mean red ('R') and blue ('B') intensity.
X = pd.DataFrame(data=features, columns=['R', 'B'])
# Label table: 1 = chelsea, 0 = manchester (as assigned above).
y = pd.DataFrame(data=Team, columns=['class'])
print(X, y)
R B 0 144.047991 108.644531 1 97.178332 65.903300 2 36.803651 33.594228 3 144.032127 116.515545 4 92.955835 100.023677 .. ... ... 117 72.367427 71.396604 118 58.110172 71.312500 119 111.657127 45.878827 120 69.949857 66.530692 121 50.770488 59.662548 [122 rows x 2 columns] class 0 0 1 0 2 0 3 0 4 1 .. ... 117 0 118 1 119 0 120 0 121 1 [122 rows x 1 columns]
# Partition the feature rows by label: 0 -> manchester, 1 -> chelsea.
X_manchester, X_chelsea = (X[y['class'] == label] for label in (0, 1))
def gmm_model(X, k, aic_bic, random_state=None):
    """Fit a full-covariance Gaussian mixture to ``X`` and print its parameters.

    Args:
        X: Samples to fit, shape (n_samples, n_features); a DataFrame or array.
        k: Number of mixture components.
        aic_bic: If truthy, also return the fitted model's AIC and BIC on ``X``.
        random_state: Seed forwarded to ``GaussianMixture``. ``None`` (default)
            keeps the previous unseeded behaviour; pass an int to make fits
            (including the order of the returned components) reproducible.

    Returns:
        ``(means, covariances)``, or ``(means, covariances, aic, bic)`` when
        ``aic_bic`` is truthy. ``means`` has shape (k, n_features) and
        ``covariances`` has shape (k, n_features, n_features).
    """
    gmm = GaussianMixture(
        n_components=k,
        covariance_type='full',
        random_state=random_state,
    ).fit(X)
    print('Converged:', gmm.converged_)
    means = gmm.means_
    covariances = gmm.covariances_
    print('\u03BC = ', means, sep="\n")
    print('\u03A3 = ', covariances, sep="\n")
    if aic_bic:
        return means, covariances, gmm.aic(X), gmm.bic(X)
    return means, covariances
# Fit a 2-component mixture to the manchester (label 0) features.
print('Team manchester:')
means_manchester, covariances_manchester = gmm_model(
    X_manchester,
    k=2,
    aic_bic=False,
)
Team manchester: Converged: True μ = [[140.66279131 89.8697372 ] [ 87.85975399 60.3903072 ]] Σ = [[[179.90559032 118.75390519] [118.75390519 463.93832578]] [[517.81964788 231.4196045 ] [231.4196045 275.8825617 ]]]
# Fit a 2-component mixture to the chelsea (label 1) features.
print('Team chelsea:')
means_chelsea, covariances_chelsea = gmm_model(
    X_chelsea,
    k=2,
    aic_bic=False,
)
Team chelsea: Converged: True μ = [[ 75.40077098 86.87922683] [106.99191421 137.20899606]] Σ = [[[ 420.48466332 396.21947624] [ 396.21947624 424.74900909]] [[1045.04163064 719.85707872] [ 719.85707872 636.67987286]]]
def draw_plot(X, means, covariances, resolution=100, padding=0.1):
    """Scatter 2-D samples and overlay a contour plot of each Gaussian component.

    Args:
        X: Array-like of shape (n_samples, 2); column 0 is drawn on the x-axis
            and column 1 on the y-axis.
        means: Iterable of component means, each of length 2 (drawn as red dots).
        covariances: Iterable of matching 2x2 covariance matrices.
        resolution: Grid points per axis used to evaluate each component pdf.
        padding: Fraction of each axis' data range added on both sides so the
            contours are not clipped at the outermost samples.
    """
    X = np.asarray(X)
    x_min, x_max = X[:, 0].min(), X[:, 0].max()
    y_min, y_max = X[:, 1].min(), X[:, 1].max()
    # Fall back to a unit pad when a column is constant (zero range).
    x_pad = (x_max - x_min) * padding or 1.0
    y_pad = (y_max - y_min) * padding or 1.0
    # Uniform evaluation grid. Building it from the sorted samples (as before)
    # gives a non-uniform grid with possibly repeated coordinates, stops at
    # the data extremes, and grows as len(X)**2.
    grid_x = np.linspace(x_min - x_pad, x_max + x_pad, resolution)
    grid_y = np.linspace(y_min - y_pad, y_max + y_pad, resolution)
    mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
    grid_points = np.column_stack([mesh_x.ravel(), mesh_y.ravel()])

    fig = plt.figure(figsize=(10, 6))
    ax0 = fig.add_subplot(111)
    ax0.scatter(X[:, 0], X[:, 1])
    for m, c in zip(means, covariances):
        pdf = multivariate_normal(mean=m, cov=c).pdf(grid_points).reshape(mesh_x.shape)
        ax0.contour(mesh_x, mesh_y, pdf, colors='green', alpha=0.3)
        ax0.scatter(m[0], m[1], c='red', zorder=10, s=100)
    plt.show()
# Draw the 2-component fit for each team, manchester first, then chelsea.
for team_X, team_means, team_covs in (
    (X_manchester, means_manchester, covariances_manchester),
    (X_chelsea, means_chelsea, covariances_chelsea),
):
    draw_plot(team_X.to_numpy(), team_means, team_covs)
# Sweep the number of components k, record AIC/BIC for each fit, and pick
# the k minimising AIC + BIC for the manchester features.
AIC = []
BIC = []
K = range(1, 20)
for k in K:
    means_temp, covariance_temp, aic, bic = gmm_model(X_manchester, k=k, aic_bic=1)
    draw_plot(X_manchester.to_numpy(), means_temp, covariance_temp)
    AIC.append(aic)
    BIC.append(bic)
plt.plot(K, AIC, 'b', label='aic')
plt.xlabel('k (number of components)')
plt.title("AIC of team Manchester")
plt.legend()
plt.show()
plt.plot(K, BIC, 'b', label='bic')
plt.xlabel('k (number of components)')
plt.title("BIC of team Manchester")
plt.legend()
plt.show()
sum_ab = [i + j for i, j in zip(AIC, BIC)]
# Map the argmin back through K instead of assuming K starts at 1.
best_k = K[int(np.argmin(sum_ab))]
# NOTE(review): at large k some fitted covariances collapse to ~1e-6 * I
# (a component sitting on a single sample, visible in the printed Σ), which
# inflates the likelihood and pulls AIC/BIC down; a large best_k here should
# be treated with suspicion.
print("best k : ", best_k)
Converged: True μ = [[97.40187287 65.71758078]] Σ = [[[869.55550034 441.52235878] [441.52235878 438.53155565]]]
Converged: True μ = [[ 87.85975399 60.3903072 ] [140.66279131 89.8697372 ]] Σ = [[[517.81964788 231.4196045 ] [231.4196045 275.8825617 ]] [[179.90559032 118.75390519] [118.75390519 463.93832578]]]
Converged: True μ = [[105.02941197 69.31005207] [ 72.53014338 50.71747264] [139.84759667 92.95369038]] Σ = [[[106.55947092 -87.58176895] [-87.58176895 147.03791534]] [[265.60113379 94.77841822] [ 94.77841822 152.28576319]] [[180.70459049 42.497404 ] [ 42.497404 284.36503122]]]
Converged: True μ = [[129.98290771 99.13287173] [ 72.22413336 50.54257704] [105.0267346 69.11560614] [149.7111139 84.39745339]] Σ = [[[140.63257257 158.43236945] [158.43236945 224.98463214]] [[258.47564708 90.41236069] [ 90.41236069 149.54740699]] [[109.72136102 -89.61515148] [-89.61515148 148.40685935]] [[ 85.07263567 99.49055688] [ 99.49055688 231.17898111]]]
Converged: True μ = [[ 55.16525673 44.26838963] [138.3296215 101.41457294] [103.55952452 69.26544588] [ 78.6166853 52.68189309] [136.26639209 67.19714379]] Σ = [[[ 80.0483634 53.52830863] [ 53.52830863 82.99035437]] [[200.55160362 78.73391773] [ 78.73391773 122.84326919]] [[ 96.39380995 -68.11611677] [-68.11611677 142.57094897]] [[ 61.44897859 -3.37828893] [ -3.37828893 152.9667958 ]] [[225.12014764 131.69783173] [131.69783173 85.42377974]]]
Converged: True μ = [[ 78.0590491 52.68017538] [ 99.24570404 79.51398379] [138.94484017 101.78358853] [127.70395444 60.13005972] [ 56.35910337 45.55671272] [101.18022153 63.7516573 ]] Σ = [[[ 32.82298073 -2.25206199] [ -2.25206199 171.42998086]] [[ 77.27503101 -35.62722641] [-35.62722641 28.71364816]] [[190.32950661 68.8697043 ] [ 68.8697043 118.09872294]] [[234.77061405 150.00061843] [150.00061843 121.47643673]] [[ 87.48295217 64.18447333] [ 64.18447333 96.64142354]] [[ 40.01173112 49.7944553 ] [ 49.7944553 77.63647791]]]
Converged: True μ = [[ 62.1816686 41.05328749] [138.51995391 101.54708055] [103.03573855 77.05503661] [116.45051925 48.5264943 ] [ 90.36911871 56.20297493] [133.25918609 65.90787412] [ 72.23166296 67.06772898]] Σ = [[[207.15206114 27.8952008 ] [ 27.8952008 74.28648218]] [[195.52227572 73.3948661 ] [ 73.3948661 119.06744127]] [[ 46.72074841 -25.62819202] [-25.62819202 23.92242874]] [[ 51.34628599 -4.33809788] [ -4.33809788 5.96609946]] [[135.52547553 80.46434559] [ 80.46434559 75.49599986]] [[230.3221053 126.25263937] [126.25263937 75.57762078]] [[127.21946636 144.04209296] [144.04209296 172.6877565 ]]]
Converged: True μ = [[149.09723222 108.67832626] [106.50789047 68.79689888] [ 88.75014152 51.48651082] [ 67.78445879 51.35970648] [138.90481945 68.8983692 ] [ 76.31859527 72.10719011] [126.36813543 93.66079967] [ 50.45428044 37.92395717]] Σ = [[[ 6.69917876e+01 -5.83536661e+01] [-5.83536661e+01 8.11387670e+01]] [[ 8.33068437e+01 -8.19950534e+01] [-8.19950534e+01 1.48713914e+02]] [[ 9.60753301e+01 9.94929512e+01] [ 9.94929512e+01 1.51201795e+02]] [[ 6.48268888e+01 -1.50232784e-01] [-1.50232784e-01 3.28957146e+00]] [[ 1.88864821e+02 1.06557999e+02] [ 1.06557999e+02 6.83741708e+01]] [[ 7.88938194e+01 7.76920970e+01] [ 7.76920970e+01 8.47443765e+01]] [[ 6.87416500e+01 4.55958592e+01] [ 4.55958592e+01 3.79365299e+01]] [[ 5.86867953e+01 9.76290172e+00] [ 9.76290172e+00 2.92621731e+01]]]
Converged: True μ = [[133.12078371 65.78534691] [ 79.16304647 68.58304747] [ 56.1019943 45.51300894] [149.08992459 108.66821181] [101.20506259 63.78931943] [ 77.60637865 46.91635159] [ 98.78905259 79.68755641] [126.6428572 93.83285063] [116.6028723 48.4012733 ]] Σ = [[[230.0701672 127.70222096] [127.70222096 77.39027126]] [[ 38.00641652 -25.078988 ] [-25.078988 18.01518424]] [[ 82.46935565 61.97598364] [ 61.97598364 96.04795268]] [[ 66.90936601 -58.2889405 ] [-58.2889405 81.20064606]] [[ 34.09681613 42.08145156] [ 42.08145156 67.28860573]] [[ 29.36506772 -6.58140504] [ -6.58140504 94.22082674]] [[ 79.54406196 -34.44365063] [-34.44365063 27.44034243]] [[ 63.52091875 42.02017864] [ 42.02017864 35.43059739]] [[ 51.28464738 -3.89382586] [ -3.89382586 5.67395306]]]
Converged: True μ = [[102.72264289 77.71779781] [ 67.41717871 51.30012134] [144.32976073 71.92212753] [126.51601955 93.76488969] [ 81.33629497 45.18135899] [117.41033097 52.78989752] [ 76.61761656 72.53641558] [ 50.39937523 37.85886583] [149.07361803 108.68015565] [100.48329807 62.81563312]] Σ = [[[ 55.93118051 -29.11841357] [-29.11841357 23.73546084]] [[ 63.08729067 -0.7574916 ] [ -0.7574916 3.24766959]] [[ 97.72575284 53.88064054] [ 53.88064054 38.89222914]] [[ 62.70326144 41.00923548] [ 41.00923548 34.50546998]] [[ 23.72307818 34.0494433 ] [ 34.0494433 110.97076081]] [[ 31.92286479 6.36158845] [ 6.36158845 29.3070504 ]] [[ 85.30918111 88.24812489] [ 88.24812489 99.44806466]] [[ 58.3661745 9.35931333] [ 9.35931333 28.86464872]] [[ 66.81591235 -58.07410965] [-58.07410965 80.72485727]] [[ 27.22703076 33.13141649] [ 33.13141649 56.49002422]]]
Converged: True μ = [[117.33044153 52.69828205] [ 49.65281653 36.73783629] [142.32076854 116.5313269 ] [101.3251806 78.43120601] [ 58.8853133 50.79740096] [ 97.3761306 60.83787452] [144.61426191 72.0703035 ] [126.33989024 93.62416504] [ 74.76814246 70.12639387] [155.63045052 101.02636121] [ 76.86407339 45.36845697]] Σ = [[[ 3.18959618e+01 5.71569223e+00] [ 5.71569223e+00 2.86847806e+01]] [[ 5.95251431e+01 3.46107055e+00] [ 3.46107055e+00 2.15885546e+01]] [[ 5.75742231e+00 -1.29023691e+01] [-1.29023691e+01 3.90491961e+01]] [[ 7.30790256e+01 -4.21729693e+01] [-4.21729693e+01 3.50640207e+01]] [[ 1.30874487e+01 4.73689411e+00] [ 4.73689411e+00 6.57837620e+00]] [[ 5.44499233e+01 4.26447543e+01] [ 4.26447543e+01 5.87170955e+01]] [[ 9.30465718e+01 5.16369289e+01] [ 5.16369289e+01 3.79347413e+01]] [[ 6.00617066e+01 3.86351777e+01] [ 3.86351777e+01 3.24832992e+01]] [[ 5.14413748e+01 4.11094374e+01] [ 4.11094374e+01 3.77610150e+01]] [[ 3.81246075e+01 8.71637357e-01] [ 8.71637357e-01 1.99303947e-02]] [[ 2.31365510e+01 -5.94640140e+00] [-5.94640140e+00 7.44169040e+01]]]
Converged: True μ = [[142.34592519 69.5145813 ] [ 68.58254141 51.18990174] [ 98.58096322 79.80743837] [160.6303498 91.87829098] [117.69783847 53.22418418] [145.99175119 110.24528544] [ 82.08693089 45.21287556] [126.47861685 93.75576843] [ 74.24614574 69.93600503] [ 50.41742224 38.06770968] [103.40544536 68.1240848 ] [ 95.10695442 52.73738042]] Σ = [[[ 34.41073984 3.65367307] [ 3.65367307 0.54834039]] [[ 62.9884775 -9.17545166] [ -9.17545166 17.1602299 ]] [[ 78.39657133 -33.31268686] [-33.31268686 26.85232133]] [[ 12.78929137 33.41779424] [ 33.41779424 87.31907338]] [[ 32.06300672 8.5841894 ] [ 8.5841894 31.65314177]] [[ 24.10777048 -42.14273982] [-42.14273982 82.91629114]] [[ 19.97735454 55.159002 ] [ 55.159002 162.8768246 ]] [[ 64.48908645 42.2823167 ] [ 42.2823167 35.39699181]] [[ 11.51138795 3.5653883 ] [ 3.5653883 4.64171965]] [[ 57.69986191 10.26245935] [ 10.26245935 31.97840975]] [[ 26.82302917 22.26736525] [ 22.26736525 23.15979521]] [[ 41.58015078 46.84478039] [ 46.84478039 54.76950381]]]
Converged: True μ = [[ 71.75613064 49.67709828] [142.34592468 69.51458112] [103.56832999 68.02788952] [145.99094249 110.24462547] [ 98.31708475 79.99671699] [ 56.44637486 45.38746106] [126.46600258 93.74604292] [117.7034785 53.23266441] [ 95.09583501 52.58841441] [160.63034984 91.87829106] [ 47.36114341 35.92210918] [ 83.18078544 46.47939742] [ 74.24980257 69.93907994]] Σ = [[[ 56.28846511 -23.11878979] [-23.11878979 29.42895099]] [[ 34.41073763 3.65367339] [ 3.65367339 0.54834049]] [[ 26.40987529 21.33318293] [ 21.33318293 23.29738033]] [[ 24.09932932 -42.12395181] [-42.12395181 82.88292948]] [[ 76.71392279 -32.05903033] [-32.05903033 25.99893399]] [[ 8.93519273 -13.86713808] [-13.86713808 23.3019297 ]] [[ 64.49284567 42.26149897] [ 42.26149897 35.36577145]] [[ 32.06633832 8.62301595] [ 8.62301595 31.70423339]] [[ 32.84854396 35.19147996] [ 35.19147996 39.65002487]] [[ 12.78929137 33.41779424] [ 33.41779424 87.31907338]] [[ 43.01693354 -4.9157283 ] [ -4.9157283 21.1932472 ]] [[ 23.59488127 73.18323744] [ 73.18323744 227.71532511]] [[ 11.49833717 3.54918974] [ 3.54918974 4.62848046]]]
Converged: True μ = [[105.99423792 75.42689943] [ 83.19501474 46.57247051] [126.31954532 93.61192693] [ 58.82480448 50.82917951] [ 99.65788227 61.63129317] [145.72448826 72.6231229 ] [127.4946588 47.11806441] [ 90.67260674 83.92221512] [155.63045054 101.02636122] [ 74.2590226 69.94635366] [ 49.70461403 36.79825675] [ 75.45641084 48.47348593] [116.02615556 54.18100198] [142.32063436 116.53194137]] Σ = [[[ 2.83291436e+01 -9.88589717e+00] [-9.88589717e+00 1.16342987e+01]] [[ 2.43320471e+01 7.54649848e+01] [ 7.54649848e+01 2.34735848e+02]] [[ 6.04906114e+01 3.88701502e+01] [ 3.88701502e+01 3.25838695e+01]] [[ 1.27808000e+01 4.66908406e+00] [ 4.66908406e+00 6.59085147e+00]] [[ 2.56981132e+01 3.07625848e+01] [ 3.07625848e+01 5.30779380e+01]] [[ 7.21074324e+01 4.13868210e+01] [ 4.13868210e+01 3.33695261e+01]] [[ 1.00000000e-06 2.82727748e-26] [ 2.82727748e-26 1.00000000e-06]] [[ 1.52831623e+01 5.22646536e-01] [ 5.22646536e-01 1.42377873e+01]] [[ 3.81246075e+01 8.71637357e-01] [ 8.71637357e-01 1.99303947e-02]] [[ 1.14783511e+01 3.52070717e+00] [ 3.52070717e+00 4.61031368e+00]] [[ 5.96041053e+01 3.78013055e+00] [ 3.78013055e+00 2.19587346e+01]] [[ 3.41967520e+01 -2.36226988e+01] [-2.36226988e+01 3.71398956e+01]] [[ 1.83054254e+01 2.15156663e+01] [ 2.15156663e+01 2.92799709e+01]] [[ 5.75763741e+00 -1.29023092e+01] [-1.29023092e+01 3.90473812e+01]]]
Converged: True μ = [[149.76534993 102.81772811] [ 74.03618675 49.85739725] [117.46399903 52.8643021 ] [ 90.15613015 84.53938454] [ 54.38334764 35.71934161] [123.78345384 91.55077443] [ 98.65413831 60.16349923] [144.36197015 71.94407385] [ 59.66767001 52.26058982] [ 82.40074261 44.5392492 ] [108.55267686 72.69360499] [141.50199298 120.26992986] [ 39.67482451 38.05636143] [ 74.26961967 69.95476367] [101.07256699 78.96368183]] Σ = [[[ 80.56234954 -11.41668178] [-11.41668178 8.86573355]] [[ 10.64518144 2.09747565] [ 2.09747565 10.49815392]] [[ 31.99743799 6.79339786] [ 6.79339786 29.74173094]] [[ 13.20548292 4.2265603 ] [ 4.2265603 11.38459826]] [[ 13.9267524 8.25046114] [ 8.25046114 18.06500602]] [[ 31.96049538 14.03547375] [ 14.03547375 13.34276394]] [[ 21.23925894 23.79808608] [ 23.79808608 43.48793942]] [[ 97.52053868 53.84321009] [ 53.84321009 38.92362015]] [[ 15.90310283 13.27053712] [ 13.27053712 20.98098342]] [[ 16.8597319 42.75233725] [ 42.75233725 132.64133128]] [[ 13.83516247 1.66605626] [ 1.66605626 2.78499639]] [[ 6.4015787 -9.49909577] [ -9.49909577 14.09540447]] [[ 8.24363809 12.81155872] [ 12.81155872 19.91063451]] [[ 11.45503629 3.48757664] [ 3.48757664 4.58685461]] [[ 15.60391086 2.95717297] [ 2.95717297 0.59493124]]]
Converged: True μ = [[ 47.36099959 35.92173127] [105.02663266 75.98796837] [134.65951945 99.60791391] [ 75.99708903 49.57337002] [ 64.96941974 52.87208158] [138.1983817 69.07920121] [117.70928128 53.23695967] [ 89.95896722 84.54705698] [155.63042092 101.02636054] [ 74.24544624 69.93566126] [ 82.38495466 44.72369879] [ 56.43223225 45.41481185] [ 99.83711761 61.8706474 ] [118.54508472 89.08622655] [141.50199298 120.26992985] [153.84705835 76.45950255]] Σ = [[[ 4.30149409e+01 -4.92203472e+00] [-4.92203472e+00 2.11888809e+01]] [[ 3.61149305e+01 -1.31860020e+01] [-1.31860020e+01 1.27250369e+01]] [[ 4.31021777e+01 4.33232896e+01] [ 4.33232896e+01 4.50231419e+01]] [[ 2.88495592e+00 5.30278853e+00] [ 5.30278853e+00 1.37713413e+01]] [[ 1.61812173e+01 -3.63641761e+00] [-3.63641761e+00 1.39291017e+01]] [[ 1.61788769e-02 6.40980563e-02] [ 6.40980563e-02 2.53962681e-01]] [[ 3.21016304e+01 8.68717775e+00] [ 8.68717775e+00 3.16542522e+01]] [[ 1.31062019e+01 4.36761683e+00] [ 4.36761683e+00 1.22565209e+01]] [[ 3.81245788e+01 8.71636665e-01] [ 8.71636665e-01 1.99303781e-02]] [[ 1.14993050e+01 3.55473721e+00] [ 3.55473721e+00 4.63155803e+00]] [[ 1.73352725e+01 4.43954915e+01] [ 4.43954915e+01 1.36400287e+02]] [[ 8.92214543e+00 -1.38634751e+01] [-1.38634751e+01 2.33243987e+01]] [[ 3.08734820e+01 3.78752990e+01] [ 3.78752990e+01 6.20781598e+01]] [[ 6.14620248e+00 -4.01814950e+00] [-4.01814950e+00 2.62691218e+00]] [[ 6.40157870e+00 -9.49909577e+00] [-9.49909577e+00 1.40954045e+01]] [[ 1.02853061e+01 1.94807071e+01] [ 1.94807071e+01 3.68971029e+01]]]
Converged: True μ = [[142.32072835 116.53151173] [ 74.01841048 49.86216489] [106.28317803 75.18031865] [142.34558407 69.51453559] [ 54.38356186 35.70435173] [126.31560699 93.61099708] [ 99.23960712 61.02732381] [120.25550905 55.40415648] [ 90.89740343 83.72128732] [ 74.27064392 69.95554723] [ 59.66223949 52.25095637] [155.63045054 101.02636122] [ 39.67482451 38.05636143] [ 75.81401467 23.79536033] [111.33645719 47.90006221] [157.05412946 82.53380102] [ 83.50548465 48.06496776]] Σ = [[[ 5.75748093e+00 -1.29023365e+01] [-1.29023365e+01 3.90486172e+01]] [[ 1.05210800e+01 2.13889799e+00] [ 2.13889799e+00 1.05023739e+01]] [[ 2.52283483e+01 -8.73266104e+00] [-8.73266104e+00 1.15999793e+01]] [[ 3.44093579e+01 3.65357079e+00] [ 3.65357079e+00 5.48340691e-01]] [[ 1.39378636e+01 8.24762157e+00] [ 8.24762157e+00 1.79507715e+01]] [[ 6.04845438e+01 3.88417889e+01] [ 3.88417889e+01 3.25523935e+01]] [[ 2.48522697e+01 2.93885707e+01] [ 2.93885707e+01 5.12171598e+01]] [[ 2.20114939e+01 -6.82770063e+00] [-6.82770063e+00 2.72280567e+01]] [[ 1.60884480e+01 -6.31839631e-01] [-6.31839631e-01 1.48251959e+01]] [[ 1.14543027e+01 3.48580691e+00] [ 3.48580691e+00 4.58580185e+00]] [[ 1.59431442e+01 1.33476656e+01] [ 1.33476656e+01 2.11047583e+01]] [[ 3.81246075e+01 8.71637357e-01] [ 8.71637357e-01 1.99303947e-02]] [[ 8.24363809e+00 1.28115587e+01] [ 1.28115587e+01 1.99106345e+01]] [[ 1.00000000e-06 8.48183245e-27] [ 8.48183245e-27 1.00000000e-06]] [[ 1.02831042e-01 -6.48155204e-01] [-6.48155204e-01 4.08543323e+00]] [[ 1.00000000e-06 6.30078982e-26] [ 6.30078982e-26 1.00000000e-06]] [[ 1.10890447e+01 2.25661450e+01] [ 2.25661450e+01 6.81702940e+01]]]
Converged: True μ = [[ 96.77635823 62.59840758] [123.60391522 91.4090081 ] [ 54.36387477 35.86915037] [ 75.45642984 49.09569411] [141.50199298 120.26992985] [124.03635204 62.08450255] [ 74.66561811 70.02232804] [ 92.89725938 83.62611651] [106.89841327 74.61488901] [116.6326457 51.77559493] [153.84705835 76.45950255] [ 58.76068186 50.70605502] [ 89.10907634 47.31416341] [140.66188078 105.49710006] [ 39.67482445 38.05636135] [155.63042095 101.02636054] [138.1983817 69.07920121] [ 75.81401467 23.79536033]] Σ = [[[ 5.89123836e+01 2.75242151e+01] [ 2.75242151e+01 1.91869193e+01]] [[ 3.01526478e+01 1.24215881e+01] [ 1.24215881e+01 1.20569805e+01]] [[ 1.36930945e+01 8.06587275e+00] [ 8.06587275e+00 1.89817158e+01]] [[ 1.99032625e+01 -3.39862309e+00] [-3.39862309e+00 1.17125087e+01]] [[ 6.40157870e+00 -9.49909577e+00] [-9.49909577e+00 1.40954045e+01]] [[ 1.00000000e-06 3.64516847e-26] [ 3.64516847e-26 1.00000000e-06]] [[ 5.04482510e+01 4.02901737e+01] [ 4.02901737e+01 3.71288514e+01]] [[ 1.77199087e+01 -1.00839413e+01] [-1.00839413e+01 1.68992747e+01]] [[ 1.69664596e+01 -5.13969594e+00] [-5.13969594e+00 1.13296763e+01]] [[ 3.00720444e+01 -8.95754914e-01] [-8.95754914e-01 2.17179484e+01]] [[ 1.02853061e+01 1.94807071e+01] [ 1.94807071e+01 3.68971029e+01]] [[ 1.27178876e+01 4.94853758e+00] [ 4.94853758e+00 6.91597657e+00]] [[ 5.93342781e+01 4.86486688e+01] [ 4.86486688e+01 4.23960724e+01]] [[ 1.14658154e+01 1.06576156e+01] [ 1.06576156e+01 9.90638590e+00]] [[ 8.24363809e+00 1.28115587e+01] [ 1.28115587e+01 1.99106345e+01]] [[ 3.81245788e+01 8.71636666e-01] [ 8.71636666e-01 1.99303781e-02]] [[ 1.61788769e-02 6.40980563e-02] [ 6.40980563e-02 2.53962681e-01]] [[ 1.00000000e-06 9.08767763e-27] [ 9.08767763e-27 1.00000000e-06]]]
Converged: True μ = [[ 84.7460528 49.96489525] [123.61375811 91.41517032] [120.2599985 55.39466052] [140.66188053 105.49709983] [ 54.38356701 35.70435762] [ 74.27057839 69.95551109] [101.07101448 78.96330236] [142.34558405 69.51453559] [ 74.17562875 49.80733223] [ 39.67482451 38.05636143] [ 98.63551713 60.13115007] [155.63042095 101.02636054] [ 76.7893779 31.16280058] [ 59.66110725 52.2518137 ] [ 90.16160068 84.53862999] [108.53799684 72.6960605 ] [141.50199298 120.26992985] [111.3364572 47.90006215] [157.05412946 82.53380102]] Σ = [[[ 4.93129106e+00 1.35526617e+01] [ 1.35526617e+01 6.11573640e+01]] [[ 2.99870043e+01 1.23139862e+01] [ 1.23139862e+01 1.19917260e+01]] [[ 2.20043436e+01 -6.75495519e+00] [-6.75495519e+00 2.70539617e+01]] [[ 1.14658154e+01 1.06576156e+01] [ 1.06576156e+01 9.90638590e+00]] [[ 1.39378520e+01 8.24757974e+00] [ 8.24757974e+00 1.79507861e+01]] [[ 1.14534101e+01 3.48503865e+00] [ 3.48503865e+00 4.58521323e+00]] [[ 1.56101102e+01 2.95863386e+00] [ 2.95863386e+00 5.95261419e-01]] [[ 3.44093587e+01 3.65357087e+00] [ 3.65357087e+00 5.48340699e-01]] [[ 1.15948854e+01 1.65034421e+00] [ 1.65034421e+00 1.04109977e+01]] [[ 8.24363809e+00 1.28115587e+01] [ 1.28115587e+01 1.99106345e+01]] [[ 2.00380468e+01 2.21047584e+01] [ 2.21047584e+01 4.12364716e+01]] [[ 3.81245788e+01 8.71636665e-01] [ 8.71636665e-01 1.99303781e-02]] [[ 9.51341217e-01 7.18598157e+00] [ 7.18598157e+00 5.42795639e+01]] [[ 1.59472133e+01 1.33564665e+01] [ 1.33564665e+01 2.11130898e+01]] [[ 1.32086627e+01 4.21752653e+00] [ 4.21752653e+00 1.13696266e+01]] [[ 1.37485106e+01 1.68917199e+00] [ 1.68917199e+00 2.81016905e+00]] [[ 6.40157870e+00 -9.49909577e+00] [-9.49909577e+00 1.40954045e+01]] [[ 1.02831042e-01 -6.48155204e-01] [-6.48155204e-01 4.08543323e+00]] [[ 1.00000000e-06 6.30078982e-26] [ 6.30078982e-26 1.00000000e-06]]]
best k : 18
# Sweep the number of components k, record AIC/BIC for each fit, and pick
# the k minimising AIC + BIC for the chelsea features.
AIC = []
BIC = []
K = range(1, 20)
for k in K:
    means_temp, covariance_temp, aic, bic = gmm_model(X_chelsea, k=k, aic_bic=1)
    draw_plot(X_chelsea.to_numpy(), means_temp, covariance_temp)
    AIC.append(aic)
    BIC.append(bic)
plt.plot(K, AIC, 'b', label='aic')
plt.xlabel('k (number of components)')
plt.title("AIC of team Chelsea")
plt.legend()
plt.show()
plt.plot(K, BIC, 'b', label='bic')
plt.xlabel('k (number of components)')
plt.title("BIC of team Chelsea")
plt.legend()
plt.show()
sum_ab = [i + j for i, j in zip(AIC, BIC)]
# Map the argmin back through K instead of assuming K starts at 1.
best_k = K[int(np.argmin(sum_ab))]
# NOTE(review): if fitted covariances collapse to ~1e-6 * I at large k
# (components on single samples), the likelihood is inflated and AIC/BIC are
# pulled down; treat a large best_k with suspicion.
print("best k : ", best_k)
Converged: True μ = [[83.06947799 99.09670979]] Σ = [[[755.54917135 767.05402072] [767.05402072 941.83114614]]]
Converged: True μ = [[ 99.25199602 125.3506001 ] [ 73.15122139 83.00571455]] Σ = [[[993.04190816 809.26265148] [809.26265148 836.31610702]] [[351.1158206 321.19693318] [321.19693318 325.12990467]]]
Converged: True μ = [[ 62.57289816 74.27354222] [ 99.5324098 125.34666842] [ 91.76761339 96.79079353]] Σ = [[[162.22127199 182.82930034] [182.82930034 230.54390737]] [[978.97194625 791.94034939] [791.94034939 803.60609035]] [[ 30.57889029 12.63904204] [ 12.63904204 11.70385812]]]
Converged: True μ = [[ 90.13046955 97.2484207 ] [104.79594285 131.01933155] [ 43.25858291 57.35917749] [ 62.55292454 74.36793862]] Σ = [[[ 38.148458 8.81165354] [ 8.81165354 14.15996686]] [[791.56383557 573.83873092] [573.83873092 566.98537671]] [[ 5.32971296 9.82655144] [ 9.82655144 136.63320386]] [[ 71.63578963 69.14999952] [ 69.14999952 86.98747455]]]
Converged: True μ = [[ 59.4374927 71.87843234] [105.69281971 132.46548931] [ 91.10862006 98.08635621] [167.33231027 182.30935108] [ 76.17240093 107.35183848]] Σ = [[[129.91266194 122.55330092] [122.55330092 157.18809759]] [[128.48897697 -2.56110994] [ -2.56110994 89.88336584]] [[ 32.24807564 10.55183143] [ 10.55183143 21.77544368]] [[ 12.97728354 21.06085578] [ 21.06085578 34.17970269]] [[ 48.57351974 -47.50497074] [-47.50497074 66.4268574 ]]]
Converged: True μ = [[ 89.56702654 96.96214489] [108.85758045 135.80404633] [ 46.48565538 58.62000464] [167.33231027 182.30935108] [ 84.24711907 113.91076227] [ 63.96599372 76.49051525]] Σ = [[[ 41.51715251 10.71301267] [ 10.71301267 15.05697359]] [[111.41907776 -51.33595092] [-51.33595092 65.93193515]] [[ 21.91210093 15.73401137] [ 15.73401137 94.1244966 ]] [[ 12.97728354 21.06085578] [ 21.06085578 34.17970269]] [[138.29095279 24.80731193] [ 24.80731193 55.39803817]] [[ 43.07651299 30.59317428] [ 30.59317428 39.89250389]]]
Converged: True μ = [[ 84.32936825 113.86389801] [ 42.8764746 46.46348825] [167.33231027 182.30935108] [ 63.8622448 76.45429854] [108.85532129 135.82998013] [ 47.61535049 63.88726266] [ 89.5738996 96.94283833]] Σ = [[[139.18330196 25.38556169] [ 25.38556169 56.3841127 ]] [[ 8.33015475 10.52160922] [ 10.52160922 13.2895829 ]] [[ 12.97728354 21.06085578] [ 21.06085578 34.17970269]] [[ 42.12667203 29.35114814] [ 29.35114814 37.80794388]] [[111.64792884 -51.44728558] [-51.44728558 65.79235137]] [[ 14.81864838 -18.7258475 ] [-18.7258475 25.96890988]] [[ 41.51043368 10.79668722] [ 10.79668722 15.00068556]]]
Converged: True μ = [[ 93.24334343 146.90250319] [ 64.11549144 75.98704934] [167.33231027 182.30935108] [ 43.22499252 57.60387426] [ 94.71620637 99.65745569] [114.15513818 131.674669 ] [ 79.89506285 102.80795361] [ 93.34629925 121.38560275]] Σ = [[[ 3.42841814e+01 -1.33918750e+00] [-1.33918750e+00 5.23115162e-02]] [[ 9.78110522e+01 9.45863978e+01] [ 9.45863978e+01 1.11572853e+02]] [[ 1.29772835e+01 2.10608558e+01] [ 2.10608558e+01 3.41797027e+01]] [[ 5.23960393e+00 9.71408767e+00] [ 9.71408767e+00 1.38577358e+02]] [[ 6.30060055e+00 7.19992871e-01] [ 7.19992871e-01 1.99254929e+01]] [[ 2.31893838e+01 1.26010633e+01] [ 1.26010633e+01 3.19483411e+01]] [[ 5.25756136e+01 -5.35908669e+01] [-5.35908669e+01 7.43316086e+01]] [[ 1.15884241e+01 -4.60530044e-01] [-4.60530044e-01 1.83027068e-02]]]
Converged: True μ = [[ 82.75546206 95.75084584] [114.1521719 131.71016241] [ 47.61585217 63.88365662] [167.33231027 182.30935108] [ 83.52391385 113.87696175] [ 94.7406069 97.75811647] [ 93.24334342 146.90250319] [ 42.87647464 46.46348831] [ 63.67018264 76.30063024]] Σ = [[[ 6.58784838e+00 3.98670637e+00] [ 3.98670637e+00 1.73137614e+01]] [[ 2.38674534e+01 1.28952887e+01] [ 1.28952887e+01 3.19865224e+01]] [[ 1.48059280e+01 -1.87001921e+01] [-1.87001921e+01 2.59269618e+01]] [[ 1.29772835e+01 2.10608558e+01] [ 2.10608558e+01 3.41797027e+01]] [[ 1.22466540e+02 1.63788760e+01] [ 1.63788760e+01 4.91595479e+01]] [[ 6.80931144e+00 5.32306185e+00] [ 5.32306185e+00 1.07725899e+01]] [[ 3.42841814e+01 -1.33918750e+00] [-1.33918750e+00 5.23115162e-02]] [[ 8.33015475e+00 1.05216092e+01] [ 1.05216092e+01 1.32895829e+01]] [[ 3.90035333e+01 2.67190573e+01] [ 2.67190573e+01 3.57833985e+01]]]
Converged: True μ = [[ 45.99568148 58.22016316] [114.15508407 131.67457936] [ 94.48310636 99.15835362] [167.33231027 182.30935108] [ 63.35946711 75.88069736] [ 93.34629925 121.38560275] [ 93.24334343 146.90250319] [ 78.17990242 104.90089538] [ 83.33653593 95.00582621] [ 65.62619579 119.45591518]] Σ = [[[ 1.93353354e+01 1.29335429e+01] [ 1.29335429e+01 9.61562675e+01]] [[ 2.31895399e+01 1.26014705e+01] [ 1.26014705e+01 3.19488431e+01]] [[ 6.51935576e+00 2.23716556e+00] [ 2.23716556e+00 2.29292405e+01]] [[ 1.29772835e+01 2.10608558e+01] [ 2.10608558e+01 3.41797027e+01]] [[ 4.20836721e+01 3.13697881e+01] [ 3.13697881e+01 4.24134593e+01]] [[ 1.15884241e+01 -4.60530044e-01] [-4.60530044e-01 1.83027068e-02]] [[ 3.42841814e+01 -1.33918750e+00] [-1.33918750e+00 5.23115162e-02]] [[ 1.03411888e+01 -1.22957510e+00] [-1.22957510e+00 1.63532312e+01]] [[ 5.74357830e+00 8.02369490e+00] [ 8.02369490e+00 1.77802175e+01]] [[ 1.00000000e-06 3.83701944e-26] [ 3.83701944e-26 1.00000000e-06]]]
Converged: True μ = [[ 63.75307696 76.32267133] [111.20888718 126.47421632] [167.33231027 182.30935108] [ 82.75014048 95.92790211] [ 46.44639345 58.57544917] [ 94.48554175 99.16765601] [ 93.24334343 146.90250319] [ 65.62619579 119.45591518] [117.10186372 136.87585252] [ 77.14743941 107.67442521] [ 93.34630094 121.38560268]] Σ = [[[ 3.95182565e+01 2.75765362e+01] [ 2.75765362e+01 3.75410022e+01]] [[ 1.13869704e+01 -3.71618435e+00] [-3.71618435e+00 6.92573728e+00]] [[ 1.29772835e+01 2.10608558e+01] [ 2.10608558e+01 3.41797027e+01]] [[ 6.71418691e+00 4.01219645e+00] [ 4.01219645e+00 1.76544331e+01]] [[ 2.14737841e+01 1.50822806e+01] [ 1.50822806e+01 9.33711813e+01]] [[ 6.49291020e+00 2.20468649e+00] [ 2.20468649e+00 2.29752908e+01]] [[ 3.42841814e+01 -1.33918750e+00] [-1.33918750e+00 5.23115162e-02]] [[ 1.00000000e-06 3.83701944e-26] [ 3.83701944e-26 1.00000000e-06]] [[ 1.76254028e+01 -1.73301560e+00] [-1.73301560e+00 2.87031288e+00]] [[ 1.14145843e+01 5.79287308e+00] [ 5.79287308e+00 2.93986991e+00]] [[ 1.15884241e+01 -4.60530044e-01] [-4.60530044e-01 1.83027068e-02]]]
Converged: True μ = [[117.10186372 136.87585252] [ 65.82219982 78.51158064] [ 94.51994579 99.25914612] [167.33231027 182.30935108] [ 43.92319029 68.21675726] [ 81.45472087 98.79229508] [ 93.24334343 146.90250319] [ 54.37426712 65.00109728] [ 93.34630095 121.38560268] [ 65.62619579 119.45591518] [111.20888718 126.47421632] [ 42.87640282 46.46339759]] Σ = [[[ 1.76254028e+01 -1.73301560e+00] [-1.73301560e+00 2.87031288e+00]] [[ 2.72581482e+01 1.10125065e+01] [ 1.10125065e+01 2.22472926e+01]] [[ 6.49202119e+00 1.95168259e+00] [ 1.95168259e+00 2.25669751e+01]] [[ 1.29772835e+01 2.10608558e+01] [ 2.10608558e+01 3.41797027e+01]] [[ 1.53620553e+00 -4.46252717e+00] [-4.46252717e+00 1.29632155e+01]] [[ 1.50228397e+01 -8.68676225e+00] [-8.68676225e+00 4.03127345e+01]] [[ 3.42841814e+01 -1.33918750e+00] [-1.33918750e+00 5.23115162e-02]] [[ 8.67262657e+00 1.40343349e+01] [ 1.40343349e+01 2.44443255e+01]] [[ 1.15884241e+01 -4.60530044e-01] [-4.60530044e-01 1.83027068e-02]] [[ 1.00000000e-06 4.22072139e-26] [ 4.22072139e-26 1.00000000e-06]] [[ 1.13869704e+01 -3.71618435e+00] [-3.71618435e+00 6.92573727e+00]] [[ 8.33015475e+00 1.05216092e+01] [ 1.05216092e+01 1.32895829e+01]]]
Converged: True μ = [[ 93.34630096 121.38560268] [ 65.37152306 74.98790518] [167.33231027 182.30935108] [ 94.5512191 99.33616174] [111.27132873 129.03207193] [ 42.87647471 46.4634884 ] [ 65.62619579 119.45591518] [ 81.75644864 99.75757302] [ 93.24334342 146.90250319] [ 73.20618789 84.52364094] [ 58.39675594 74.79714562] [ 48.93850298 64.39447382] [119.92287268 136.95989962]] Σ = [[[ 1.15884241e+01 -4.60530044e-01] [-4.60530044e-01 1.83027068e-02]] [[ 5.97359433e-01 2.08280687e+00] [ 2.08280687e+00 7.42400959e+00]] [[ 1.29772835e+01 2.10608558e+01] [ 2.10608558e+01 3.41797027e+01]] [[ 6.44954251e+00 1.73462057e+00] [ 1.73462057e+00 2.22294180e+01]] [[ 8.55254602e+00 -2.30801720e+00] [-2.30801720e+00 2.48279165e+01]] [[ 8.33015475e+00 1.05216092e+01] [ 1.05216092e+01 1.32895829e+01]] [[ 1.00000000e-06 3.83701944e-26] [ 3.83701944e-26 1.00000000e-06]] [[ 1.70842288e+01 -1.25731156e+01] [-1.25731156e+01 3.41564108e+01]] [[ 3.42841814e+01 -1.33918750e+00] [-1.33918750e+00 5.23115162e-02]] [[ 2.11286795e+01 7.62833891e+00] [ 7.62833891e+00 1.43507675e+01]] [[ 5.36942417e-01 3.21343875e+00] [ 3.21343875e+00 2.34046013e+01]] [[ 1.85535296e+01 -1.19516174e+01] [-1.19516174e+01 2.17436196e+01]] [[ 2.55875115e+00 -3.31079865e+00] [-3.31079865e+00 4.28388454e+00]]]
Converged: True μ = [[ 94.52165481 99.26206373] [ 51.35184044 59.46500355] [114.15508676 131.67458226] [163.72991071 176.4630102 ] [ 71.28962364 83.05083493] [ 93.34629926 121.38560275] [ 93.24334343 146.90250319] [ 65.62619579 119.45591518] [ 81.43985115 98.68549703] [ 65.37456661 74.99877534] [170.93470982 188.15569196] [ 42.87647481 46.46348852] [ 43.92319037 68.21675702] [ 57.54569809 73.16409471]] Σ = [[[ 6.48817815e+00 1.94273913e+00] [ 1.94273913e+00 2.25517227e+01]] [[ 3.37973049e-01 -1.14843296e-01] [-1.14843296e-01 3.90248859e-02]] [[ 2.31894855e+01 1.26014106e+01] [ 1.26014106e+01 3.19487807e+01]] [[ 1.00000000e-06 1.47018429e-25] [ 1.47018429e-25 1.00000000e-06]] [[ 1.01848602e+01 -3.37224890e+00] [-3.37224890e+00 8.30193447e+00]] [[ 1.15884241e+01 -4.60530044e-01] [-4.60530044e-01 1.83027068e-02]] [[ 3.42841814e+01 -1.33918750e+00] [-1.33918750e+00 5.23115162e-02]] [[ 1.00000000e-06 3.63507105e-26] [ 3.63507105e-26 1.00000000e-06]] [[ 1.49357458e+01 -8.42344726e+00] [-8.42344726e+00 4.08531237e+01]] [[ 5.98752877e-01 2.08805933e+00] [ 2.08805933e+00 7.44345255e+00]] [[ 1.00000000e-06 1.57519746e-25] [ 1.57519746e-25 1.00000000e-06]] [[ 8.33015475e+00 1.05216092e+01] [ 1.05216092e+01 1.32895829e+01]] [[ 1.53620553e+00 -4.46252717e+00] [-4.46252717e+00 1.29632155e+01]] [[ 3.33379496e+00 8.16144860e+00] [ 8.16144860e+00 2.94900625e+01]]]
Converged: True μ = [[ 49.28493743 61.18527731] [111.20888718 126.47421632] [ 82.69579035 95.80923444] [167.33231027 182.30935108] [ 61.61114821 74.93455151] [ 77.1474395 107.67442526] [ 93.24334343 146.90250319] [ 93.34630102 121.38560268] [ 94.74491826 97.77565841] [ 42.87647479 46.4634885 ] [117.10186372 136.87585252] [ 73.47705875 82.78564878] [ 42.68375319 71.81720344] [ 92.73509247 108.92179528] [ 65.62619579 119.45591518]] Σ = [[[ 8.74529672e+00 -7.16791633e+00] [-7.16791633e+00 5.92824546e+00]] [[ 1.13869704e+01 -3.71618435e+00] [-3.71618435e+00 6.92573728e+00]] [[ 6.71634476e+00 4.25220348e+00] [ 4.25220348e+00 1.80408207e+01]] [[ 1.29772835e+01 2.10608558e+01] [ 2.10608558e+01 3.41797027e+01]] [[ 1.87497442e+01 1.33347926e+01] [ 1.33347926e+01 2.69426070e+01]] [[ 1.14145843e+01 5.79287308e+00] [ 5.79287308e+00 2.93986991e+00]] [[ 3.42841814e+01 -1.33918750e+00] [-1.33918750e+00 5.23115162e-02]] [[ 1.15884241e+01 -4.60530044e-01] [-4.60530044e-01 1.83027068e-02]] [[ 6.83468918e+00 5.32542296e+00] [ 5.32542296e+00 1.07625306e+01]] [[ 8.33015475e+00 1.05216092e+01] [ 1.05216092e+01 1.32895829e+01]] [[ 1.76254028e+01 -1.73301560e+00] [-1.73301560e+00 2.87031288e+00]] [[ 9.15604663e-01 -3.34386421e+00] [-3.34386421e+00 1.22120840e+01]] [[ 1.00000000e-06 1.44393100e-26] [ 1.44393100e-26 1.00000000e-06]] [[ 1.00000000e-06 5.14968399e-26] [ 5.14968399e-26 1.00000000e-06]] [[ 1.00000000e-06 3.83701944e-26] [ 3.83701944e-26 1.00000000e-06]]]
Converged: True μ = [[117.10186372 136.87585252] [ 64.67096816 77.56685055] [ 94.48523338 99.16759563] [163.72991071 176.4630102 ] [ 42.87647479 46.4634885 ] [ 93.34630094 121.38560268] [ 77.14743951 107.67442526] [ 43.92319037 68.21675702] [ 93.24334343 146.90250319] [ 52.28058336 61.8455614 ] [ 82.69329156 95.8058391 ] [ 65.62619579 119.45591518] [111.20888718 126.47421632] [ 57.9855149 72.46311735] [ 73.4770411 82.78571327] [170.93470982 188.15569196]] Σ = [[[ 1.76254028e+01 -1.73301560e+00] [-1.73301560e+00 2.87031288e+00]] [[ 5.84978343e+00 -1.08438013e+00] [-1.08438013e+00 1.84049909e+01]] [[ 6.49387636e+00 2.20486408e+00] [ 2.20486408e+00 2.29769162e+01]] [[ 1.00000000e-06 1.47018429e-25] [ 1.47018429e-25 1.00000000e-06]] [[ 8.33015475e+00 1.05216092e+01] [ 1.05216092e+01 1.32895829e+01]] [[ 1.15884241e+01 -4.60530044e-01] [-4.60530044e-01 1.83027068e-02]] [[ 1.14145843e+01 5.79287308e+00] [ 5.79287308e+00 2.93986991e+00]] [[ 1.53620553e+00 -4.46252717e+00] [-4.46252717e+00 1.29632155e+01]] [[ 3.42841814e+01 -1.33918750e+00] [-1.33918750e+00 5.23115162e-02]] [[ 1.95058290e+00 4.34563354e+00] [ 4.34563354e+00 1.13609761e+01]] [[ 6.73096633e+00 4.27136825e+00] [ 4.27136825e+00 1.80679352e+01]] [[ 1.00000000e-06 3.83701944e-26] [ 3.83701944e-26 1.00000000e-06]] [[ 1.13869704e+01 -3.71618435e+00] [-3.71618435e+00 6.92573728e+00]] [[ 3.78917610e-02 4.41439477e-01] [ 4.41439477e-01 9.41211848e+00]] [[ 9.15604663e-01 -3.34386422e+00] [-3.34386422e+00 1.22120840e+01]] [[ 1.00000000e-06 1.57519746e-25] [ 1.57519746e-25 1.00000000e-06]]]
Converged: True μ = [[ 47.63814326 63.84067721] [111.27132874 129.03207192] [ 80.04625828 104.14641782] [170.93470982 188.15569196] [ 62.81420066 80.23511679] [ 96.06860377 105.51426707] [ 65.62619579 119.45591518] [ 99.09861288 146.67378827] [ 93.34630102 121.38560268] [ 60.80372801 71.36281678] [ 42.87647481 46.46348852] [119.92287268 136.95989962] [ 93.97832777 97.05128936] [ 73.47704201 82.78570992] [ 82.6777653 93.17611408] [ 87.38807398 147.13121811] [163.72991071 176.4630102 ]] Σ = [[[ 1.47354596e+01 -1.85462763e+01] [-1.85462763e+01 2.56556792e+01]] [[ 8.55254591e+00 -2.30801702e+00] [-2.30801702e+00 2.48279162e+01]] [[ 1.88892577e+01 -5.24444895e+00] [-5.24444895e+00 1.48628016e+01]] [[ 1.00000000e-06 1.57519746e-25] [ 1.57519746e-25 1.00000000e-06]] [[ 1.55704840e+01 5.20167016e+00] [ 5.20167016e+00 7.23572183e+00]] [[ 1.11130609e+01 -1.13598130e+01] [-1.13598130e+01 1.16120460e+01]] [[ 1.00000000e-06 3.83701944e-26] [ 3.83701944e-26 1.00000000e-06]] [[ 1.00000000e-06 6.66429693e-26] [ 6.66429693e-26 1.00000000e-06]] [[ 1.15884241e+01 -4.60530044e-01] [-4.60530044e-01 1.83027068e-02]] [[ 1.91802432e+01 1.14844680e+01] [ 1.14844680e+01 8.28763834e+00]] [[ 8.33015475e+00 1.05216092e+01] [ 1.05216092e+01 1.32895829e+01]] [[ 2.55875115e+00 -3.31079865e+00] [-3.31079865e+00 4.28388454e+00]] [[ 3.71427832e+00 2.28716226e+00] [ 2.28716226e+00 8.91259142e+00]] [[ 9.15604663e-01 -3.34386422e+00] [-3.34386422e+00 1.22120840e+01]] [[ 5.21866230e+00 4.48074311e+00] [ 4.48074311e+00 5.97738043e+00]] [[ 1.00000000e-06 6.22001047e-26] [ 6.22001047e-26 1.00000000e-06]] [[ 1.00000000e-06 1.47018429e-25] [ 1.47018429e-25 1.00000000e-06]]]
Converged: True μ = [[ 62.64777495 75.93828151] [ 93.34630102 121.38560268] [170.93470982 188.15569196] [ 94.74505824 97.77572569] [ 49.2887702 61.18210566] [117.10186372 136.87585252] [ 93.24334343 146.90250319] [ 65.62619579 119.45591518] [ 77.14743941 107.67442521] [ 82.72017386 95.86187833] [111.20888718 126.47421632] [ 69.91845066 81.93167372] [ 56.65071839 69.1109273 ] [163.72991071 176.4630102 ] [ 39.99027423 42.81800064] [ 42.68375319 71.81720344] [ 92.73509247 108.92179528] [ 45.76267538 50.1089764 ]] Σ = [[[ 9.56053384e+00 -6.34400084e+00] [-6.34400084e+00 1.16846103e+01]] [[ 1.15884241e+01 -4.60530044e-01] [-4.60530044e-01 1.83027068e-02]] [[ 1.00000000e-06 1.57519746e-25] [ 1.57519746e-25 1.00000000e-06]] [[ 6.83409686e+00 5.32518572e+00] [ 5.32518572e+00 1.07623267e+01]] [[ 8.73784226e+00 -7.16157191e+00] [-7.16157191e+00 5.92289868e+00]] [[ 1.76254028e+01 -1.73301560e+00] [-1.73301560e+00 2.87031288e+00]] [[ 3.42841814e+01 -1.33918750e+00] [-1.33918750e+00 5.23115162e-02]] [[ 1.00000000e-06 3.83701944e-26] [ 3.83701944e-26 1.00000000e-06]] [[ 1.14145843e+01 5.79287308e+00] [ 5.79287308e+00 2.93986991e+00]] [[ 6.71399803e+00 4.13815364e+00] [ 4.13815364e+00 1.78544883e+01]] [[ 1.13869704e+01 -3.71618435e+00] [-3.71618435e+00 6.92573728e+00]] [[ 1.55932746e+01 3.23105439e+00] [ 3.23105439e+00 1.15365223e+01]] [[ 3.19153124e+00 3.28036903e+00] [ 3.28036903e+00 3.73858009e+00]] [[ 1.00000000e-06 1.47018429e-25] [ 1.47018429e-25 1.00000000e-06]] [[ 1.00000000e-06 7.87598728e-27] [ 7.87598728e-27 1.00000000e-06]] [[ 1.00000000e-06 1.44393100e-26] [ 1.44393100e-26 1.00000000e-06]] [[ 1.00000000e-06 5.14968399e-26] [ 5.14968399e-26 1.00000000e-06]] [[ 1.00000000e-06 1.13091099e-26] [ 1.13091099e-26 1.00000000e-06]]]
Converged: True μ = [[ 92.73509247 108.92179528] [ 52.28058308 61.84556058] [170.93470982 188.15569196] [111.20888718 126.47421632] [ 79.51921237 89.17849171] [ 42.87647479 46.4634885 ] [ 65.62619579 119.45591518] [ 65.36890951 74.97831566] [ 87.38807398 147.13121811] [ 94.74711068 97.77822614] [ 77.14743915 107.67442508] [ 43.92319037 68.21675702] [117.10186372 136.87585252] [ 83.33415602 97.13181603] [ 68.3649216 82.71952256] [ 99.09861288 146.67378827] [ 93.34630102 121.38560268] [163.72991071 176.4630102 ] [ 57.98571366 72.46733794]] Σ = [[[ 1.00000000e-06 5.14968399e-26] [ 5.14968399e-26 1.00000000e-06]] [[ 1.95057811e+00 4.34562059e+00] [ 4.34562059e+00 1.13609408e+01]] [[ 1.00000000e-06 1.57519746e-25] [ 1.57519746e-25 1.00000000e-06]] [[ 1.13869704e+01 -3.71618435e+00] [-3.71618435e+00 6.92573728e+00]] [[ 1.00000000e-06 3.39273298e-26] [ 3.39273298e-26 1.00000000e-06]] [[ 8.33015475e+00 1.05216092e+01] [ 1.05216092e+01 1.32895829e+01]] [[ 1.00000000e-06 3.83701944e-26] [ 3.83701944e-26 1.00000000e-06]] [[ 5.95484929e-01 2.07579352e+00] [ 2.07579352e+00 7.39829469e+00]] [[ 1.00000000e-06 6.22001047e-26] [ 6.22001047e-26 1.00000000e-06]] [[ 6.82247011e+00 5.31617122e+00] [ 5.31617122e+00 1.07529653e+01]] [[ 1.14145843e+01 5.79287308e+00] [ 5.79287308e+00 2.93986991e+00]] [[ 1.53620553e+00 -4.46252717e+00] [-4.46252717e+00 1.29632155e+01]] [[ 1.76254028e+01 -1.73301560e+00] [-1.73301560e+00 2.87031288e+00]] [[ 5.65782390e+00 2.56789305e-02] [ 2.56789305e-02 1.11079423e+01]] [[ 3.29435905e+01 2.33811301e-01] [ 2.33811301e-01 6.56953593e+00]] [[ 1.00000000e-06 6.66429693e-26] [ 6.66429693e-26 1.00000000e-06]] [[ 1.15884241e+01 -4.60530044e-01] [-4.60530044e-01 1.83027068e-02]] [[ 1.00000000e-06 1.47018429e-25] [ 1.47018429e-25 1.00000000e-06]] [[ 3.78688607e-02 4.41478718e-01] [ 4.41478718e-01 9.41874857e+00]]]
best k : 19
# Load the penguins table from the mounted Drive and display it.
penguins_path = '/content/drive/MyDrive/penguins.csv'
dataset = pd.read_csv(penguins_path)
dataset
| | species | island | culmen_length_mm | culmen_depth_mm | flipper_length_mm | body_mass_g | sex |
|---|---|---|---|---|---|---|---|
| 0 | Adelie | Torgersen | 39.1 | 18.7 | 181.0 | 3750.0 | MALE |
| 1 | Adelie | Torgersen | 39.5 | 17.4 | 186.0 | 3800.0 | FEMALE |
| 2 | Adelie | Torgersen | 40.3 | 18.0 | 195.0 | 3250.0 | FEMALE |
| 3 | Adelie | Torgersen | NaN | NaN | NaN | NaN | NaN |
| 4 | Adelie | Torgersen | 36.7 | 19.3 | 193.0 | 3450.0 | FEMALE |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 339 | Gentoo | Biscoe | NaN | NaN | NaN | NaN | NaN |
| 340 | Gentoo | Biscoe | 46.8 | 14.3 | 215.0 | 4850.0 | FEMALE |
| 341 | Gentoo | Biscoe | 50.4 | 15.7 | 222.0 | 5750.0 | MALE |
| 342 | Gentoo | Biscoe | 45.2 | 14.8 | 212.0 | 5200.0 | FEMALE |
| 343 | Gentoo | Biscoe | 49.9 | 16.1 | 213.0 | 5400.0 | MALE |
344 rows × 7 columns
# Count missing (NaN) entries in every column.
dataset.isnull().sum()
species 0 island 0 culmen_length_mm 2 culmen_depth_mm 2 flipper_length_mm 2 body_mass_g 2 sex 10 dtype: int64
# Tally the values present in the sex column.
dataset.sex.value_counts()
MALE 168 FEMALE 165 . 1 Name: sex, dtype: int64
# List the distinct species names.
dataset.species.unique()
array(['Adelie', 'Chinstrap', 'Gentoo'], dtype=object)
# Impute missing values: each numeric measurement column gets its own column
# mean; the categorical 'sex' column gets 'MALE', and its stray '.' value is
# also mapped to 'MALE'.
# The results are assigned back to the column instead of calling
# dataset[col].fillna(..., inplace=True). That chained in-place form acts on
# an intermediate object and does not modify `dataset` under pandas
# Copy-on-Write; it warns in pandas 2.x.
numeric_cols = ['culmen_depth_mm', 'culmen_length_mm', 'flipper_length_mm', 'body_mass_g']
for col in numeric_cols:
    dataset[col] = dataset[col].fillna(dataset[col].mean())
dataset['sex'] = dataset['sex'].fillna('MALE').replace({'.': 'MALE'})
dataset.isna().sum()
species 0 island 0 culmen_length_mm 0 culmen_depth_mm 0 flipper_length_mm 0 body_mass_g 0 sex 0 dtype: int64
# Tally how many rows come from each island.
dataset['island'].value_counts()
Biscoe 168 Dream 124 Torgersen 52 Name: island, dtype: int64
# Encode the island and sex text columns as small integers, in place.
island_codes = {"Torgersen": 1, "Dream": 2, "Biscoe": 3}
sex_codes = {"MALE": 1, "FEMALE": 2}
dataset.replace({"island": island_codes, "sex": sex_codes}, inplace=True)
def draw_scatter_plots(X1, X2, X3, features1, features2):
    """Show one scatter plot per feature pair, overlaying the three groups.

    For each position k, features1[k] is plotted on the x-axis against
    features2[k] on the y-axis for X1, X2 and X3, labelled "Adelie",
    "Gentoo" and "Chinstrap" respectively. A legend is added and the figure
    is shown before moving to the next pair.
    """
    groups = ((X1, "Adelie"), (X2, "Gentoo"), (X3, "Chinstrap"))
    for x_feature, y_feature in zip(features1, features2):
        for frame, species_name in groups:
            plt.scatter(frame[x_feature], frame[y_feature], label=species_name)
        plt.xlabel(x_feature)
        plt.ylabel(y_feature)
        plt.title(f"{x_feature} - {y_feature} scatter plot")
        plt.legend()
        plt.show()
# Feature pairs used throughout: pair k is (features1[k], features2[k]).
features1 = ['culmen_length_mm', 'flipper_length_mm', 'body_mass_g', 'flipper_length_mm']
features2 = ['culmen_depth_mm', 'culmen_length_mm', 'flipper_length_mm', 'culmen_depth_mm']

# Split the cleaned table by species.
species_col = dataset['species']
X1 = dataset[species_col == 'Adelie']
X2 = dataset[species_col == 'Gentoo']
X3 = dataset[species_col == 'Chinstrap']

draw_scatter_plots(X1, X2, X3, features1, features2)
def gmm_model(X, features1, features2):
    """Fit one single-component, full-covariance Gaussian mixture per feature pair.

    NOTE(review): this replaces an earlier gmm_model(X, k=..., aic_bic=...)
    used higher up in the notebook; cells calling that signature will fail
    once this definition has run.

    For each pair (features1[k], features2[k]), fits
    GaussianMixture(n_components=1, covariance_type='full') to the two
    columns of X. It prints the pair, the convergence flag and the fitted
    parameters.

    Returns:
        (means, covariances): lists with one entry per pair, holding the
        fitted ``means_`` and ``covariances_`` arrays respectively.
    """
    means, covariances = [], []
    for first, second in zip(features1, features2):
        print(f'FOR {first} AND {second} :')
        pair_values = X[[first, second]].to_numpy()
        mixture = GaussianMixture(n_components=1, covariance_type='full')
        mixture.fit(pair_values)
        print('Converged:', mixture.converged_)
        means.append(mixture.means_)
        covariances.append(mixture.covariances_)
        print('\u03BC = ', mixture.means_, sep="\n")
        print('\u03A3 = ', mixture.covariances_, sep="\n")
    return means, covariances
def calculate_error(X, y, feature1, feature2, means, covariances, features_ind):
    """Classify rows of X by the most likely class Gaussian and score against y.

    For each class i in {0, 1, 2}, a bivariate normal is built from
    ``means[i][features_ind][0]`` and ``covariances[i][features_ind][0]``.
    This is the layout produced by gmm_model: one (1, 2) mean array and one
    (1, 2, 2) covariance array per feature pair. Every sample of
    ``X[[feature1, feature2]]`` is assigned to the class with the highest
    density and compared with its label.

    Args:
        X: DataFrame containing the two feature columns; must be non-empty.
        y: Class labels aligned *positionally* with the rows of X (a list,
            array or Series; a Series' index is ignored).
        feature1, feature2: Column names forming each 2-D sample.
        means, covariances: Per-class, per-feature-pair Gaussian parameters.
        features_ind: Index of the feature pair inside means[i] / covariances[i].

    Returns:
        (error, accuracy): fractions of misclassified / correctly classified rows.

    Raises:
        ValueError: if X has no rows.
    """
    data = X[[feature1, feature2]].to_numpy()
    if len(data) == 0:
        raise ValueError("calculate_error: X has no rows to classify")
    labels = np.asarray(y)

    # Build each class distribution once and score all samples in one call.
    # Log-densities are compared instead of densities: a pdf can underflow to
    # 0 for every class, which would leave no prediction at all.
    log_densities = np.column_stack([
        np.atleast_1d(
            multivariate_normal(
                mean=means[i][features_ind][0],
                cov=covariances[i][features_ind][0],
            ).logpdf(data)
        )
        for i in range(3)
    ])
    # argmax picks the first maximum, i.e. ties go to the lowest class index.
    predictions = np.argmax(log_densities, axis=1)

    true = int(np.count_nonzero(predictions == labels))
    false = len(data) - true
    error = false / len(X)
    accuracy = true / len(X)
    return error, accuracy
def draw_gmm_plot(X, labels, features1, features2, means, covariances, grid_size=100):
    """Plot each feature pair per species with its Gaussian, then print its score.

    For every pair j = (features1[j], features2[j]):
      * scatter-plots the Adelie, Gentoo and Chinstrap rows of X;
      * overlays, for species i (0 Adelie, 1 Gentoo, 2 Chinstrap), the density
        contours of N(means[i][j][0], covariances[i][j][0]) and marks the mean
        in red;
      * shows the figure and prints the error/accuracy returned by
        calculate_error for that pair against ``labels``.

    Args:
        X: DataFrame with a 'species' column and the feature columns.
        labels: Class per row of X, positionally aligned with X, in the same
            species order as above.
        features1, features2: Parallel lists of x / y feature names.
        means, covariances: Per-species, per-pair parameters (species first).
        grid_size: Number of grid points per axis for the contour evaluation.
    """
    # Split the DataFrame that was passed in (not a global) by species. The
    # order must match the first index of means/covariances.
    species_frames = [
        ("Adelie", X[X['species'] == 'Adelie']),
        ("Gentoo", X[X['species'] == 'Gentoo']),
        ("Chinstrap", X[X['species'] == 'Chinstrap']),
    ]
    for j, (f1, f2) in enumerate(zip(features1, features2)):
        for name, frame in species_frames:
            plt.scatter(frame[f1], frame[f2], label=name)
        plt.xlabel(f1)
        plt.ylabel(f2)
        plt.title(f1 + " - " + f2 + " scatter plot")
        error, accuracy = calculate_error(X, labels, f1, f2, means, covariances, j)

        for i, (_, frame) in enumerate(species_frames):
            data = frame[[f1, f2]].to_numpy()
            mean = means[i][j][0]
            dist = multivariate_normal(mean=mean, cov=covariances[i][j][0])
            # Regular, strictly increasing grid over this species' data range
            # (sorted raw samples would give duplicate, unevenly spaced nodes).
            grid_x = np.linspace(data[:, 0].min(), data[:, 0].max(), grid_size)
            grid_y = np.linspace(data[:, 1].min(), data[:, 1].max(), grid_size)
            mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
            # pdf treats the last axis as the 2 components -> (grid_size, grid_size)
            density = dist.pdf(np.dstack((mesh_x, mesh_y)))
            plt.contour(grid_x, grid_y, density, alpha=0.3)
            plt.scatter(mean[0], mean[1], c='red', zorder=10, s=100)

        plt.legend()
        plt.show()
        print('error = ', error)
        print('accuracy = ', accuracy)
# Integer class labels in the same species order as the parameter lists built
# below (0 = Adelie, 1 = Gentoo, 2 = Chinstrap). Series.map yields integers
# directly; Series.replace on object dtype relies on implicit downcasting,
# which pandas has deprecated.
y = dataset['species'].map({'Adelie': 0, 'Gentoo': 1, 'Chinstrap': 2})
x1_means, x1_covariances = gmm_model(X1, features1, features2)
FOR culmen_length_mm AND culmen_depth_mm : Converged: True μ = [[38.82514428 18.33849454]] Σ = [[[7.17242248 1.21183321] [1.21183321 1.47009688]]] FOR flipper_length_mm AND culmen_length_mm : Converged: True μ = [[190.02575793 38.82514428]] Σ = [[[42.98711183 5.96716097] [ 5.96716097 7.17242248]]] FOR body_mass_g AND flipper_length_mm : Converged: True μ = [[3703.95891043 190.02575793]] Σ = [[[2.09157074e+05 1.42145560e+03] [1.42145560e+03 4.29871118e+01]]] FOR flipper_length_mm AND culmen_depth_mm : Converged: True μ = [[190.02575793 18.33849454]] Σ = [[[42.98711183 2.3296682 ] [ 2.3296682 1.47009688]]]
# Gentoo: one Gaussian per feature pair.
x2_means, x2_covariances = gmm_model(X=X2, features1=features1, features2=features2)
FOR culmen_length_mm AND culmen_depth_mm : Converged: True μ = [[47.47598331 14.99960621]] Σ = [[[9.44734828 1.85203058] [1.85203058 0.98490029]]] FOR flipper_length_mm AND culmen_length_mm : Converged: True μ = [[217.05576778 47.47598331]] Σ = [[[43.49463806 13.46708554] [13.46708554 9.44734828]]] FOR body_mass_g AND flipper_length_mm : Converged: True μ = [[5068.96576118 217.05576778]] Σ = [[[2.56148531e+05 2.37389278e+03] [2.37389278e+03 4.34946381e+01]]] FOR flipper_length_mm AND culmen_depth_mm : Converged: True μ = [[217.05576778 14.99960621]] Σ = [[[43.49463806 4.14114194] [ 4.14114194 0.98490029]]]
# Chinstrap: one Gaussian per feature pair.
x3_means, x3_covariances = gmm_model(X=X3, features1=features1, features2=features2)
FOR culmen_length_mm AND culmen_depth_mm : Converged: True μ = [[48.83382353 18.42058824]] Σ = [[[10.98665109 2.44136246] [ 2.44136246 1.27016536]]] FOR flipper_length_mm AND culmen_length_mm : Converged: True μ = [[195.82352941 48.83382353]] Σ = [[[50.11591796 11.06626298] [11.06626298 10.98665109]]] FOR body_mass_g AND flipper_length_mm : Converged: True μ = [[3733.08823529 195.82352941]] Σ = [[[1.45541198e+05 1.73267734e+03] [1.73267734e+03 5.01159180e+01]]] FOR flipper_length_mm AND culmen_depth_mm : Converged: True μ = [[195.82352941 18.42058824]] Σ = [[[50.11591796 4.62863322] [ 4.62863322 1.27016536]]]
# Stack per-species parameters so that means[i][j] is species i
# (X1, X2, X3 order), feature pair j.
means = [x1_means, x2_means, x3_means]
covariances = [x1_covariances, x2_covariances, x3_covariances]
draw_gmm_plot(
    X=dataset,
    labels=y,
    features1=features1,
    features2=features2,
    means=means,
    covariances=covariances,
)
error = 0.03197674418604651 accuracy = 0.9680232558139535
error = 0.04941860465116279 accuracy = 0.9505813953488372
error = 0.19476744186046513 accuracy = 0.8052325581395349
error = 0.20348837209302326 accuracy = 0.7965116279069767
def aic_bic(X, f1, f2, species, k_values=(2, 3, 4, 5), random_state=None):
    """Pick a GMM component count for one feature pair by minimising AIC + BIC.

    Fits GaussianMixture(n_components=k, covariance_type='full') to
    X[[f1, f2]] for each k in ``k_values``. It plots AIC (blue) and BIC (red)
    against k under the title ``species``, then prints the k with the
    smallest AIC + BIC.

    Args:
        X: DataFrame containing columns f1 and f2.
        f1, f2: Feature column names.
        species: Plot title.
        k_values: Candidate component counts (default 2..5, as before).
        random_state: Forwarded to GaussianMixture. Pass an int for a
            reproducible initialisation; None keeps the previous unseeded
            behaviour.

    Returns:
        The selected k.

    Raises:
        ValueError: if ``k_values`` is empty.
    """
    K = list(k_values)
    if not K:
        raise ValueError("aic_bic: k_values must not be empty")
    # The data does not depend on k, so extract it once.
    data = X[[f1, f2]].to_numpy()
    AIC = []
    BIC = []
    for k in K:
        GMM = GaussianMixture(
            n_components=k, covariance_type='full', random_state=random_state
        ).fit(data)
        AIC.append(GMM.aic(data))
        BIC.append(GMM.bic(data))
    plt.plot(K, AIC, 'b', label='aic')
    plt.plot(K, BIC, 'r', label='bic')
    plt.title(species)
    plt.legend()  # the curves are labelled; show which is which
    plt.show()
    sum_ab = [i + j for i, j in zip(AIC, BIC)]
    best_k = K[sum_ab.index(min(sum_ab))]
    print("best k : ", best_k)
    return best_k
# Model selection on the first feature pair (culmen length vs culmen depth),
# once per species.
f1, f2 = features1[0], features2[0]
for frame, name in ((X1, 'Adelie'), (X2, 'Gentoo'), (X3, 'Chinstrap')):
    aic_bic(frame, f1, f2, name)
best k : 2
best k : 2
best k : 5
from sklearn import datasets

# 500 samples of scikit-learn's two interleaving half circles, with Gaussian
# noise of standard deviation 0.11. A fixed random_state makes the generated
# sample, and everything later fitted on it, reproducible across runs.
noisy_moons = datasets.make_moons(n_samples=500, noise=0.11, random_state=0)
noisy_moons
(array([[ 2.06941243e+00, 4.58227287e-01],
[-7.80188965e-01, 7.16853733e-01],
[ 1.48218903e-01, 9.67525082e-01],
[ 2.36978827e-01, 9.71785410e-01],
[ 1.50553559e+00, -4.80087815e-01],
[-1.10412344e+00, 1.31304495e-01],
[ 1.39791814e+00, -4.47609262e-01],
[ 1.80751322e+00, -6.41427035e-02],
[ 1.47036495e+00, -2.76418522e-01],
[ 7.21410000e-02, 1.47867702e-01],
[ 1.27602759e-01, 1.00202121e+00],
[-1.10399451e+00, 4.07427952e-01],
[ 3.72281535e-02, 1.02728809e+00],
[-8.58302452e-02, 1.19324724e+00],
[ 9.49274804e-01, 2.79376152e-01],
[-8.12820650e-01, 6.62307717e-01],
[ 8.06501143e-01, 7.23851693e-01],
[ 7.95528776e-01, 7.09829151e-01],
[-1.05058184e+00, 1.50173886e-01],
[ 5.68918773e-01, -1.38452784e-01],
[-7.96815945e-01, 5.63939981e-01],
[ 9.40809639e-01, -4.32324563e-01],
[ 1.52320676e-01, 1.02589201e+00],
[-4.67337280e-01, 6.75673810e-01],
[-3.29312393e-01, 9.37756483e-01],
[ 3.27731594e-02, 1.18196381e-01],
[ 1.94784107e+00, 1.57708712e-01],
[ 3.29101658e-01, -3.68483627e-01],
[ 6.46153104e-01, -6.57487068e-01],
[ 3.82562888e-01, 1.30741687e-01],
[ 7.55435016e-01, -3.98684223e-01],
[ 1.00394805e+00, 3.12039266e-01],
[ 3.52825240e-02, 2.97197781e-01],
[ 2.59125823e-01, -1.32295483e-01],
[ 1.30426910e+00, -4.86624758e-01],
[-2.22980856e-02, 8.42261977e-01],
[ 1.61437789e+00, -9.84196215e-02],
[-5.18784494e-01, 1.00828233e+00],
[-4.03614883e-01, 9.83636128e-01],
[-2.03757095e-01, 3.45830592e-01],
[-5.75487987e-01, 6.83701721e-01],
[-5.95412805e-01, 4.62859268e-01],
[ 8.09502640e-01, -4.95675413e-01],
[ 7.88367416e-01, 5.79128482e-01],
[-1.04982601e+00, 4.62956902e-01],
[ 7.76094711e-02, 3.74810638e-01],
[ 5.31272657e-01, 7.57645779e-01],
[ 9.85690347e-01, 3.39861867e-01],
[ 1.75640542e+00, 2.20575905e-01],
[ 9.65925397e-01, -3.71870957e-01],
[ 7.91917624e-01, 1.62886904e-01],
[ 1.22312506e-01, 2.37632531e-01],
[ 5.39291602e-01, -2.48266263e-01],
[-8.64075674e-01, 2.38997750e-01],
[-5.01094293e-01, 7.68368806e-01],
[ 1.31379159e-02, 1.06426125e+00],
[-1.08658090e-01, 7.84233793e-01],
[ 1.07799645e+00, -9.70524132e-02],
[ 1.59413849e+00, -2.75525202e-01],
[ 1.90973742e+00, -6.81001478e-02],
[ 1.83630285e-01, -3.67262177e-02],
[ 2.36870659e-01, -4.97981147e-02],
[-7.07778053e-01, 5.83823302e-01],
[ 1.39022768e+00, -6.14871981e-01],
[ 6.09089542e-01, -4.27775142e-01],
[-1.03951871e+00, 2.21682865e-01],
[ 7.52095348e-01, 6.06566044e-01],
[ 1.53787274e+00, -2.85344973e-01],
[ 2.13363873e-01, 2.20921581e-01],
[ 1.16277080e+00, -6.88744774e-01],
[ 1.82401389e-02, 1.01369522e+00],
[-3.98516366e-02, 3.61091838e-01],
[ 1.58708276e+00, -2.69412906e-01],
[-1.26695154e-01, 9.22730412e-01],
[-4.99383486e-01, 1.00294376e+00],
[ 6.58151390e-01, -5.21768278e-01],
[ 2.01590067e+00, 2.35321567e-01],
[-1.07837712e+00, 5.60323649e-02],
[-4.19704463e-01, 9.09230010e-01],
[ 5.31695841e-01, 1.02252198e+00],
[-8.84299737e-01, -7.47940933e-03],
[ 3.88225936e-01, 1.00773482e+00],
[ 8.55673309e-01, 2.77236822e-01],
[ 2.92608299e-01, 1.08502683e+00],
[ 2.33128048e-01, 9.03198286e-02],
[ 4.04829388e-01, 1.05495083e+00],
[-8.42628493e-01, 5.33391286e-01],
[ 1.61219764e-01, 1.58263753e-02],
[ 6.91837485e-01, 5.91271333e-01],
[ 6.60924790e-01, -3.26738507e-01],
[ 1.39434693e+00, -4.75483246e-01],
[-1.73402668e-01, 9.43588670e-01],
[ 1.77008489e+00, -3.01461382e-01],
[ 5.52623413e-01, 9.53619115e-01],
[ 9.96829783e-01, -5.71016160e-01],
[ 1.16255132e+00, -5.42965014e-01],
[ 2.00159480e+00, 4.66810984e-01],
[-7.58052026e-01, 8.38330575e-01],
[ 1.32781119e+00, -5.65886684e-01],
[-9.31638635e-01, 6.10498232e-01],
[-2.41888092e-01, 9.60071688e-01],
[ 8.33954482e-01, 3.19655332e-01],
[ 4.66164032e-01, -4.06190954e-01],
[ 1.09811107e+00, -3.37639417e-01],
[ 3.41266009e-01, 7.88520686e-01],
[ 1.02449315e+00, -7.25880552e-02],
[ 9.57265505e-01, 1.23790523e-01],
[ 1.01319168e+00, 5.11112239e-02],
[-8.76476340e-01, -1.00658627e-01],
[ 9.08449563e-01, 3.21688376e-01],
[-9.53801128e-01, 5.13797241e-01],
[ 1.02817989e+00, 1.08527531e-01],
[-3.43838762e-01, 8.96765309e-01],
[ 1.54924222e+00, -2.72237485e-01],
[ 1.45478412e+00, -3.19293663e-01],
[ 7.87133695e-01, -2.86180835e-01],
[ 1.54533191e+00, -2.43038985e-01],
[ 1.23938691e-01, 1.96769129e-01],
[ 6.33652379e-01, -4.42441805e-01],
[ 1.55403765e+00, -3.40442289e-01],
[ 1.06424363e+00, 4.79315413e-01],
[ 1.18950481e+00, -4.19217819e-01],
[ 8.79738279e-01, -2.04159400e-01],
[ 1.18319536e+00, -4.69761344e-01],
[ 3.08532497e-01, 7.84751859e-01],
[ 1.30328388e+00, -2.19294550e-01],
[ 8.37112676e-01, 5.53253441e-01],
[ 1.09514253e+00, 8.89575757e-02],
[-8.93138179e-01, 3.76461842e-01],
[ 7.70758709e-01, 7.84165945e-01],
[ 1.44410960e+00, -2.76689231e-01],
[ 9.25215218e-01, 1.88646704e-01],
[ 1.91490123e+00, -2.88515622e-02],
[ 1.56784604e+00, -2.80307762e-01],
[ 2.02030976e+00, 4.99938914e-01],
[ 1.02906824e+00, 8.93766388e-02],
[ 1.89024505e+00, 3.12708088e-01],
[ 2.94510913e-01, 9.20425076e-01],
[ 8.64982718e-01, -6.94075323e-01],
[ 8.08297375e-01, 6.26441513e-01],
[-8.80713884e-01, 2.78380406e-04],
[ 8.41060193e-01, 3.85091823e-01],
[-1.04095503e-01, 9.63145631e-01],
[ 1.74701728e+00, 2.57236771e-01],
[ 1.88401188e+00, 2.07122735e-01],
[-9.77370855e-01, -3.75932624e-02],
[-1.18062881e-01, 9.01058436e-01],
[ 9.62467041e-01, 1.97418193e-01],
[ 5.21806753e-01, -3.03387366e-01],
[-6.17743227e-01, 7.94279291e-01],
[-9.02940357e-01, 4.45227003e-01],
[-4.18722478e-01, 9.02281748e-01],
[ 5.30874450e-02, -2.87287776e-02],
[ 1.75582954e+00, -2.10228223e-01],
[ 9.03910067e-01, -4.23444019e-01],
[ 8.12783336e-01, 6.83020171e-01],
[ 2.29057745e-01, -4.68703039e-01],
[ 1.13407781e-01, 3.05712044e-01],
[ 1.22675918e-01, 3.04130952e-01],
[ 5.28704518e-01, -4.66465624e-01],
[-4.13040931e-01, 7.21672349e-01],
[-1.16201723e-01, 1.23627723e+00],
[ 4.08941932e-01, 7.37943288e-01],
[ 9.59834818e-01, 1.57389364e-01],
[ 1.80004676e+00, -1.08078152e-01],
[ 1.27225503e+00, -4.24024071e-01],
[-9.87654305e-01, 8.82196389e-02],
[ 7.57039227e-01, 5.42106507e-01],
[ 1.98563041e+00, 1.59336177e-01],
[-8.30731275e-02, -6.29156030e-02],
[ 2.61637042e-01, -2.04403122e-01],
[ 8.63332150e-01, 5.43692254e-01],
[ 1.84792108e-01, 3.10255428e-01],
[ 1.92028918e+00, 2.09527311e-01],
[-1.25350675e-01, 1.02764963e+00],
[ 9.88136023e-02, 1.11132483e+00],
[ 1.40389562e+00, -3.43898672e-01],
[ 2.14509960e+00, 2.95500390e-01],
[-1.04937396e+00, 5.03864682e-01],
[ 1.87917394e+00, 1.68458916e-02],
[ 8.28505093e-01, -5.36051143e-01],
[ 2.10564772e+00, 3.83943573e-01],
[ 7.91903058e-01, 6.01980530e-01],
[ 1.98214223e+00, 1.87939838e-01],
[ 9.57304417e-01, 1.23275003e-01],
[ 5.41921705e-02, 1.84151214e-02],
[ 1.63413519e+00, -4.17699291e-01],
[ 9.19929694e-01, -2.72976644e-01],
[ 1.29064095e-01, 1.43833871e-01],
[ 7.73084718e-01, -3.40320143e-01],
[ 6.23942264e-01, 8.51834681e-01],
[ 1.98512365e-01, 9.02542850e-01],
[ 3.44189982e-02, 9.37724390e-02],
[ 4.45133845e-01, -4.27545847e-02],
[ 6.99548953e-02, 2.30550162e-01],
[ 3.26454432e-01, -3.67126641e-01],
[ 1.12227422e+00, 1.31032456e-02],
[ 1.99879719e+00, 6.11194082e-01],
[ 2.13320681e+00, 5.05057767e-01],
[ 6.38487766e-02, 2.65247160e-02],
[-9.96096418e-01, 1.07757239e-01],
[-1.25439980e-01, 1.03507219e+00],
[ 6.27756134e-01, 6.25711633e-01],
[-9.89482315e-01, 4.02796431e-01],
[ 9.20011077e-01, 2.04323746e-01],
[-1.08205166e+00, 7.88278236e-01],
[ 7.25510481e-01, 6.71157133e-01],
[ 5.53385378e-01, 6.12617733e-01],
[ 1.06513040e+00, 3.28524391e-01],
[ 1.72099906e+00, -3.29251377e-01],
[-5.76784098e-01, 7.61513338e-01],
[-4.16921194e-02, 3.30571802e-01],
[ 1.51396168e-01, -3.46990229e-03],
[ 7.08880217e-01, -6.01575351e-01],
[ 8.16298193e-01, 5.34635058e-01],
[ 1.02034670e+00, 1.35766073e-01],
[ 5.13939513e-01, -3.27180440e-01],
[ 6.72406864e-02, -3.08963462e-01],
[-9.20920187e-01, 2.82553709e-01],
[-8.22157492e-01, 6.04041186e-01],
[-4.79948838e-01, 6.63972637e-01],
[ 1.41893327e+00, -4.06311107e-01],
[ 1.03382406e-01, 8.79569292e-01],
[ 2.30233055e-01, -2.33491390e-01],
[ 1.91790030e+00, 8.28435633e-02],
[ 7.84056036e-01, 3.03010028e-01],
[-4.20594117e-02, 1.00932589e+00],
[ 3.21183441e-01, 8.18808095e-01],
[ 7.47387442e-01, -3.67880395e-01],
[-1.38332423e-01, 2.06574550e-01],
[-9.56785030e-01, 3.88283752e-01],
[ 4.85979840e-01, 1.08281102e+00],
[ 1.73981956e+00, -2.36395038e-01],
[-3.20207359e-02, 3.80635306e-01],
[ 1.72475353e+00, -1.63436191e-01],
[ 1.17133450e+00, -4.92153013e-01],
[ 1.99525454e+00, 6.84235159e-01],
[-7.36261742e-01, 5.85422129e-01],
[ 7.97531282e-01, 3.97434924e-01],
[-5.83943106e-01, 7.22557833e-01],
[ 1.74203575e+00, -3.21865912e-02],
[ 3.34237444e-01, -1.38559859e-02],
[ 6.41023401e-01, 7.94987540e-01],
[ 1.82495318e+00, 3.85670687e-01],
[ 9.99632220e-02, 1.46084971e-01],
[-3.30494056e-01, 7.69092976e-01],
[-9.62102288e-01, 3.05768502e-01],
[ 1.16155631e-01, 3.61984500e-01],
[ 3.81062894e-01, -3.40482730e-01],
[ 1.76269721e+00, -1.79985542e-01],
[-7.59278308e-01, 5.80992203e-01],
[-4.70756171e-01, 7.35036389e-01],
[ 8.67720001e-01, 5.24407235e-01],
[ 7.69714487e-03, 2.84016989e-01],
[-3.17428200e-01, 6.70453145e-01],
[ 2.01656618e+00, 3.31024169e-01],
[ 3.54793191e-01, 8.28974963e-01],
[-7.63238148e-01, 5.82058372e-01],
[ 2.57122058e-01, -2.62246283e-01],
[ 2.90299260e-01, -3.34844961e-01],
[ 9.97439379e-01, 3.43446773e-01],
[ 2.14081523e+00, 3.52989771e-01],
[ 1.48093747e+00, -2.88848395e-01],
[-1.23581619e-01, 5.13156893e-02],
[-8.22558078e-01, -5.78950166e-02],
[-9.54009287e-01, 4.69999102e-01],
[-9.95456111e-01, 1.59981669e-01],
[ 8.03639563e-01, -5.51748405e-01],
[ 1.46259361e+00, -1.78442474e-01],
[-1.04917627e-01, 3.97620552e-01],
[ 2.10173398e-01, -6.78564209e-02],
[ 1.49507575e-01, -1.19766482e-01],
[-1.57280633e-01, 1.02463752e+00],
[-1.00995411e-01, 4.00463133e-01],
[-9.60085665e-01, 2.95712336e-01],
[-1.06009842e+00, 1.78272375e-01],
[ 8.72773539e-01, 5.90164747e-01],
[-4.13478118e-01, 1.18104195e+00],
[ 1.91459904e+00, 7.13607599e-02],
[ 7.52100284e-01, -3.08542595e-01],
[-9.22053151e-01, 5.02246319e-02],
[ 1.42225741e+00, -2.61439031e-01],
[ 1.72696996e-01, 1.82327836e-01],
[ 4.52363373e-01, -2.70101452e-01],
[ 1.99742765e+00, -2.78826432e-01],
[ 1.18742990e+00, 2.41348572e-01],
[-8.01765777e-01, 8.23957243e-01],
[ 7.39136114e-01, -4.55651419e-01],
[ 1.81680911e+00, -1.35595040e-01],
[ 1.82261552e+00, 4.43201736e-01],
[ 1.73094849e-01, 8.21591054e-01],
[-5.54166660e-01, 7.76465834e-01],
[ 1.81340500e-01, -8.06131579e-02],
[-1.99448554e-03, 1.01783459e+00],
[ 1.48260165e-01, 1.01743513e+00],
[-5.40771418e-02, 9.19338909e-01],
[-5.28379027e-01, 1.07220750e+00],
[ 9.90485360e-01, 3.55600902e-01],
[-5.70162136e-01, 7.14828763e-01],
[ 1.62657731e+00, 6.12692155e-02],
[ 3.11374920e-02, 4.01818127e-01],
[ 6.36139807e-01, 9.65979488e-01],
[ 1.89250557e+00, -4.74318463e-02],
[ 9.37063344e-01, 2.77721721e-01],
[ 3.64944189e-01, 8.05124360e-01],
[ 1.88952634e+00, 9.08023472e-02],
[ 2.13198631e+00, 5.62754822e-01],
[ 2.85083762e-01, 9.94546179e-01],
[-8.56510342e-01, 6.35588172e-01],
[ 7.14069706e-01, 7.62984167e-01],
[-9.57392914e-01, 1.38479775e-01],
[-9.42700895e-01, 2.61432683e-01],
[ 5.06772660e-02, 8.89528339e-02],
[ 1.68706043e+00, -3.38084329e-01],
[ 7.12824970e-01, 8.44919702e-01],
[-1.43696503e-01, 1.01211445e+00],
[ 1.26771354e+00, -7.73553683e-01],
[ 1.82945296e-01, 1.96505868e-01],
[ 1.37191279e+00, -3.20090165e-01],
[-1.03319411e+00, 1.57465093e-01],
[ 6.90693404e-01, 7.37916931e-01],
[ 3.26348635e-01, -2.26323865e-01],
[ 1.76519957e+00, -2.82644640e-01],
[ 1.05521049e+00, -6.64910714e-01],
[ 2.00576903e+00, 1.79039053e-01],
[-3.38477112e-01, 9.07262279e-01],
[ 1.98106066e+00, 1.84058004e-01],
[ 7.91893412e-01, 5.35693570e-01],
[ 2.22404531e-01, 9.61770925e-02],
[ 8.93191695e-01, 5.75976838e-01],
[-1.00255364e+00, 4.36603082e-01],
[-9.15091593e-01, 7.14290568e-01],
[ 1.25860005e+00, -3.94735257e-01],
[ 8.41826593e-01, -6.52254847e-01],
[ 4.99953927e-01, -1.87317358e-01],
[ 8.58076942e-01, -5.78154865e-01],
[ 1.98655456e+00, -2.00415808e-01],
[ 3.30473697e-01, -4.74259220e-01],
[ 1.76693048e+00, -1.27931492e-01],
[ 1.27067371e-01, 9.81263315e-01],
[ 7.42570969e-01, -2.20284894e-01],
[ 1.14423157e+00, -5.27269386e-01],
[ 1.75718540e+00, -1.18020919e-01],
[-2.91195027e-01, 9.33989133e-01],
[ 1.06434250e+00, -3.58051747e-01],
[ 2.54373160e-01, 9.02096463e-01],
[ 1.87826179e+00, 1.99059399e-01],
[ 9.49705822e-01, 6.68314887e-01],
[ 8.13844420e-01, 5.52339121e-01],
[-2.62390082e-01, 9.55020813e-01],
[ 8.92150080e-01, -4.48033469e-01],
[-1.00184368e+00, 3.24962208e-01],
[-8.16828176e-01, 8.19289011e-01],
[ 7.00995016e-01, -4.35016088e-01],
[ 2.13814120e+00, 2.33978243e-01],
[ 1.27792828e+00, -3.42371593e-01],
[ 2.86780247e-01, -4.13063596e-01],
[ 1.31975630e+00, -3.57243412e-01],
[ 1.72187833e+00, -6.11314322e-02],
[ 4.37355438e-01, 1.07492446e+00],
[-9.52423234e-01, 3.39405545e-01],
[-8.58370329e-01, 2.18693362e-01],
[ 1.24133433e+00, -5.31004501e-01],
[ 2.28580075e-01, -5.61107134e-02],
[ 2.78459691e-01, -2.51150609e-01],
[ 9.51457099e-01, 3.27490156e-01],
[ 4.08187346e-01, -4.22755424e-01],
[ 5.08223093e-01, 9.02084858e-01],
[ 6.26531595e-01, 7.22578295e-01],
[ 7.55068716e-01, 3.25932811e-01],
[ 1.99615480e+00, 2.95160960e-01],
[-4.27587479e-01, 8.90406522e-01],
[ 6.14573477e-01, 5.99275459e-01],
[-4.62472294e-01, 8.73266943e-01],
[ 6.37462329e-01, 8.55101533e-01],
[-1.09792962e+00, 3.82168065e-01],
[ 1.23926484e+00, -4.80652189e-01],
[ 4.72058070e-01, -4.38599362e-01],
[ 1.86715616e+00, -1.42956754e-02],
[ 2.38577270e-01, -2.13086873e-02],
[-2.06850531e-03, 1.38626610e-01],
[ 9.47432442e-01, 4.84861414e-01],
[ 9.70997866e-01, 5.45180935e-01],
[-3.57831056e-01, 8.72904221e-01],
[ 1.62233480e+00, -4.47795838e-01],
[ 7.27575807e-01, 7.40212317e-01],
[ 1.36064980e-01, 6.96386606e-02],
[ 1.51345943e+00, -3.37549839e-01],
[-5.84562924e-01, 8.16680621e-01],
[ 1.00437314e+00, 8.78820887e-02],
[-6.20222331e-02, 5.39954729e-01],
[-6.89101145e-01, 6.67558233e-01],
[ 6.22845254e-01, 9.08156448e-01],
[-2.53004044e-02, 9.10546127e-01],
[-2.04304822e-01, 9.97291190e-01],
[-2.44133494e-01, 9.47681186e-01],
[ 6.48232826e-01, 8.08130363e-01],
[-2.94588842e-01, 8.58699280e-01],
[ 1.58770016e+00, -4.40219199e-01],
[ 5.82801334e-02, 5.15416105e-01],
[ 2.02452554e+00, 9.70289653e-02],
[-1.50803157e-01, 5.56358692e-01],
[ 1.12886314e+00, -7.73122218e-01],
[ 9.05019853e-01, 1.61728957e-01],
[ 7.03206302e-01, -3.71516302e-01],
[ 3.72939261e-01, 9.09046232e-01],
[ 1.96870153e-01, 5.75149298e-02],
[-2.03473360e-01, 1.02925319e+00],
[ 4.37201212e-01, -3.56056246e-01],
[ 1.39518290e+00, -6.16944698e-01],
[ 1.69315516e+00, -2.87452577e-01],
[ 7.68432892e-01, 1.87394751e-01],
[ 6.49017393e-01, 8.85886953e-01],
[-6.15635392e-01, 6.65823920e-01],
[ 1.08743863e+00, -5.47718066e-01],
[-8.39714279e-01, 2.22374130e-01],
[ 1.58194312e-01, 1.06342042e+00],
[ 9.28147026e-01, -1.53645852e-02],
[-3.11796752e-01, 9.00799983e-01],
[ 1.02815005e+00, -5.45767289e-01],
[ 4.41465981e-01, 9.99158053e-01],
[ 8.59374913e-01, -5.98578861e-01],
[-1.47308866e-01, 1.05119743e+00],
[-9.72944529e-01, 4.84935022e-01],
[ 1.74332652e+00, -3.52030059e-01],
[-7.95721709e-01, 4.19936302e-01],
[ 1.23839351e+00, -5.76951286e-01],
[ 3.04316500e-01, 1.03022739e+00],
[-1.08471517e+00, 1.23771199e-01],
[-5.39456607e-02, 1.21419424e+00],
[ 1.02157170e-01, 9.08292529e-01],
[ 1.57124008e-01, 4.67867330e-01],
[-3.70141904e-01, 8.37819958e-01],
[-1.13070867e+00, 3.14355725e-01],
[-3.90677864e-01, 9.38671520e-01],
[-6.87884118e-01, 6.49208795e-01],
[ 1.24776506e+00, 2.76064726e-01],
[ 3.06591002e-01, 7.88036676e-01],
[ 2.04038992e-01, 6.80996038e-01],
[-7.58596949e-01, 8.47736316e-01],
[ 1.51252906e+00, -2.49050224e-01],
[ 5.12299082e-01, -4.66973456e-01],
[ 2.52725238e-01, 8.09531675e-02],
[ 1.68778165e+00, -9.39961803e-02],
[ 1.17415089e+00, -6.14249718e-01],
[-8.02333205e-01, 2.41514133e-01],
[ 1.00045411e+00, -4.37655582e-01],
[ 1.54642265e-01, -1.17407287e-02],
[-3.41296241e-01, 1.01462785e+00],
[ 1.02068565e+00, -5.69697235e-01],
[ 1.95744069e+00, 3.88628855e-01],
[ 5.26351944e-01, -5.22776873e-01],
[ 9.18282065e-01, 5.73472207e-01],
[ 5.99641871e-02, 1.02445939e+00],
[ 5.56728744e-01, -3.58381110e-01],
[ 5.39212315e-01, -5.34658968e-01],
[-5.70001262e-02, 1.07556367e+00],
[-2.22216673e-02, 2.23007753e-01],
[ 1.75776899e+00, 4.82419678e-02],
[ 1.09539838e+00, -4.56582981e-01],
[-8.72747894e-01, 5.02265810e-01],
[ 2.11709506e+00, -1.84640411e-03],
[ 1.89099918e+00, 2.36550534e-01],
[ 6.91281528e-01, -5.51648549e-01],
[ 1.62297481e+00, 1.00579634e-01],
[ 1.08160892e+00, -6.06687460e-01],
[ 1.84846370e+00, 2.28878795e-01],
[-5.92357213e-02, 3.91339052e-01],
[ 1.03636962e-01, 4.13595083e-01],
[-7.59511619e-01, 3.13446360e-01],
[ 1.99315526e+00, 3.15344396e-01],
[ 1.97059437e+00, 7.24447484e-02],
[ 2.00724866e+00, 4.51937406e-01],
[-9.78531843e-01, 4.10893830e-01],
[ 1.17389506e+00, 5.47833787e-02],
[ 4.04834809e-01, -5.75112600e-03],
[ 9.83598109e-01, 5.36254359e-01],
[ 1.69049384e-01, 1.05989627e+00],
[ 3.11892495e-01, 7.38275702e-01],
[-9.60104654e-01, -5.73605573e-02],
[-6.45470723e-01, 2.78027737e-01],
[ 1.47604530e+00, -4.93264046e-01],
[ 1.58990426e+00, -1.50849637e-01],
[ 9.71596746e-01, 1.80920134e-01],
[ 9.19160913e-01, -3.39387032e-01],
[-7.76978083e-01, 6.71030845e-01],
[ 1.33119034e+00, -3.67458615e-01],
[ 1.61388126e+00, -3.50550516e-01],
[ 5.65163894e-01, 9.90258871e-01],
[ 3.10582778e-01, 1.68972168e-01],
[ 1.91018757e+00, 2.11348398e-01],
[ 1.27705286e+00, -3.61096552e-01],
[ 4.88550570e-01, -2.74938169e-01],
[-4.64474478e-02, 3.24422391e-01],
[ 6.10161072e-01, 7.37635387e-01],
[ 4.32259721e-01, -2.70182673e-01],
[-8.94825161e-01, 2.84687782e-01],
[-5.03690538e-01, 8.02777589e-01],
[ 4.18028088e-01, 8.98471550e-01],
[ 1.05073250e+00, 5.10734562e-02]]),
array([1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1,
0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0,
0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0,
0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1,
0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0,
1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1,
1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0,
1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0,
0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1,
0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1,
0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0,
0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0,
0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0,
0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0,
0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1,
1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0,
1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0]))
X, y = noisy_moons
# Partition the samples by class: a falsy label (0) goes to X1, any other label to X2.
X1 = np.array([point for point, label in zip(X, y) if not label])
X2 = np.array([point for point, label in zip(X, y) if label])
# Class 0 is scattered first, then class 1, so the default colour cycle matches.
plt.scatter(X1[:, 0], X1[:, 1])
plt.scatter(X2[:, 0], X2[:, 1])
plt.show()
from numpy.linalg import inv, det
def parameteres_basedon_bayes_estimation(X, mean_0, covariance_0):
    """Bayesian estimate of a Gaussian class density with a Gaussian prior on its mean.

    The data covariance Sigma is taken to be the sample covariance of ``X``.
    The unknown mean has prior N(mean_0, covariance_0). With n samples and
    sample mean m, and S0 = covariance_0:

        mean_n  = S0 (S0 + Sigma/n)^-1 m + (1/n) Sigma (S0 + Sigma/n)^-1 mean_0
        Sigma_n = (S0/n) (S0 + Sigma/n)^-1 Sigma

    Parameters
    ----------
    X : array_like, shape (n, d)
        Samples of one class.
    mean_0 : array_like, d elements
        Prior mean, either (d,) or (d, 1).
    covariance_0 : array_like, shape (d, d)
        Prior covariance of the mean.

    Returns
    -------
    mean_n : ndarray, shape (d, 1)
        Posterior mean of the class mean.
    covariance : ndarray, shape (d, d)
        ``Sigma + Sigma_n``, the covariance of the predictive density.
    """
    X = np.asarray(X, dtype=float)
    n = len(X)
    t_covariance = np.atleast_2d(np.cov(X.T))
    # Column-vector sample mean for any dimension d (previously hard-coded to 2).
    t_mean = X.mean(axis=0).reshape(-1, 1)
    # Accept (d,) as well as (d, 1); a 1-D prior mean would otherwise broadcast
    # the sum below into a (d, d) matrix.
    mean_0 = np.asarray(mean_0, dtype=float).reshape(-1, 1)
    covariance_0 = np.asarray(covariance_0, dtype=float)
    # (S0 + Sigma/n)^-1 appears in all three terms; compute it once.
    shared_inv = inv(covariance_0 + t_covariance / n)
    covariance_n = (covariance_0 / n).dot(shared_inv).dot(t_covariance)
    mean_n = (covariance_0.dot(shared_inv).dot(t_mean)
              + t_covariance.dot(shared_inv).dot(mean_0) / n)
    return mean_n, t_covariance + covariance_n
def gaussian_pdf(X, mean, covariance):
    """Evaluate the multivariate normal density N(X | mean, covariance) at one point.

    Parameters
    ----------
    X : array_like, d elements
        The point to evaluate, e.g. shape (d,) or (d, 1).
    mean : array_like, d elements
        Mean vector, e.g. shape (d,) or (d, 1).
    covariance : array_like, shape (d, d)
        Covariance matrix; must be invertible.

    Returns
    -------
    ndarray, shape (1, 1)
        The density value. It is kept as a 1x1 array because callers read it
        with ``[0][0]``.
    """
    # Column vectors of whatever length the input has (previously fixed at 2).
    X = np.asarray(X, dtype=float).reshape(-1, 1)
    mean = np.asarray(mean, dtype=float).reshape(-1, 1)
    covariance = np.atleast_2d(covariance)
    d = X.shape[0]
    diff = X - mean
    exponent = -0.5 * diff.T.dot(inv(covariance)).dot(diff)
    return np.exp(exponent) / (((2 * np.pi) ** d) * det(covariance)) ** 0.5
def draw_gaussian_plot(X, means, covariances, color, ax=None):
    """Scatter a 2-D point set and overlay density contours of Gaussian components.

    Densities are evaluated on the grid formed by the sorted x and sorted y
    coordinates of ``X`` (a len(X) x len(X) grid). Each component's mean is
    marked with a large green dot.

    Parameters
    ----------
    X : ndarray, shape (n, 2)
        Points to scatter; they also define the evaluation grid.
    means, covariances : sequences
        Matching sequences of component means and covariance matrices, as
        accepted by ``gaussian_pdf``.
    color : str
        Colour used for the scatter and the contour lines.
    ax : matplotlib Axes, optional
        Axes to draw on. When omitted, the module-level ``ax0`` that the
        notebook creates before each call is used (the original behaviour).
    """
    if ax is None:
        ax = ax0  # notebook-global axes, kept for backward compatibility
    # The grid does not depend on the component, so build it once.
    xs = np.sort(X[:, 0])
    ys = np.sort(X[:, 1])
    grid_x, grid_y = np.meshgrid(xs, ys)
    grid_points = np.array([grid_x.flatten(), grid_y.flatten()]).T
    ax.scatter(X[:, 0], X[:, 1], c=color)
    for m, c in zip(means, covariances):
        density = np.array([gaussian_pdf(p, m, c)[0][0] for p in grid_points])
        ax.contour(xs, ys, density.reshape(grid_x.shape), colors=color, alpha=0.3)
        ax.scatter(m[0], m[1], c='green', zorder=10, s=100)
# Prior means (column vectors) for class 0 and class 1, plus the prior
# covariance shared by both classes.
mean1 = np.array([0, 0.5]).reshape(2, 1)
mean2 = np.array([1, 0]).reshape(2, 1)
covariance = np.array([[1, 0.5],
                       [0.5, 0.5]])
mean_n_1, covariance_n_1 = parameteres_basedon_bayes_estimation(X1, mean1, covariance)
mean_n_2, covariance_n_2 = parameteres_basedon_bayes_estimation(X2, mean2, covariance)
fig = plt.figure(figsize=(10, 6))
# draw_gaussian_plot draws onto the module-level name ax0.
ax0 = fig.add_subplot(111)
draw_gaussian_plot(X1, [mean_n_1], [covariance_n_1], 'red')
draw_gaussian_plot(X2, [mean_n_2], [covariance_n_2], 'blue')
plt.show()
لینک ها: https://github.com/saniikakulkarni/Gaussian-Mixture-Model-from-scratch/blob/main/Gaussian_Mixture_Model_from_scratch.ipynb https://towardsdatascience.com/gaussian-mixture-models-implemented-from-scratch-1857e40ea566
class GMM:
    """Gaussian mixture model with full covariance matrices, fitted by EM.

    Parameters
    ----------
    k : int
        Number of mixture components.
    max_iter : int, default 5
        Number of EM iterations run by ``fit``. The count is fixed; there is
        no convergence test.
    reg_covar : float, default 1e-6
        Added to the diagonal of every covariance estimate so that a component
        which collapses onto very few points stays invertible.

    Attributes (set by ``fit``)
    ---------------------------
    means : list of ndarray, each shape (d,)
    covariances : list of ndarray, each shape (d, d)
    alpha : ndarray, shape (k,)
        Mixing weights.
    weights : ndarray, shape (n, k)
        Responsibilities from the most recent E-step.
    n : int
        Number of samples passed to ``fit``.
    """

    def __init__(self, k, max_iter=5, reg_covar=1e-6):
        self.k = k
        self.max_iter = int(max_iter)
        self.reg_covar = reg_covar
        self.means = []
        self.covariances = []
        self.alpha = np.full(shape=self.k, fill_value=1 / self.k)

    def _log_joint(self, X):
        """Return the (n, k) matrix log(alpha_j) + log N(x_i | mean_j, cov_j)."""
        X = np.asarray(X, dtype=float)
        n, d = X.shape
        log_joint = np.empty((n, self.k))
        with np.errstate(divide='ignore'):
            # A component whose weight reached exactly 0 gets log-weight -inf.
            log_alpha = np.log(self.alpha)
        for j in range(self.k):
            diff = X - np.ravel(self.means[j])
            cov = self.covariances[j]
            # Squared Mahalanobis distance of every row, via a linear solve.
            mahalanobis = np.einsum('ij,ji->i', diff, np.linalg.solve(cov, diff.T))
            _, log_det = np.linalg.slogdet(cov)
            log_joint[:, j] = log_alpha[j] - 0.5 * (
                mahalanobis + d * np.log(2 * np.pi) + log_det
            )
        return log_joint

    @staticmethod
    def _logsumexp_rows(a):
        """Numerically stable log(sum(exp(a), axis=1)) for a 2-D array."""
        row_max = a.max(axis=1, keepdims=True)
        return row_max[:, 0] + np.log(np.exp(a - row_max).sum(axis=1))

    def e_step(self, X):
        """Compute responsibilities for X and refresh the mixing weights from them."""
        self.weights = self.predict_proba(X)
        self.alpha = self.weights.mean(axis=0)

    def m_step(self, X):
        """Re-estimate each component's mean and covariance from the responsibilities."""
        X = np.asarray(X, dtype=float)
        d = X.shape[1]
        for i in range(self.k):
            weight = self.weights[:, i]
            total_weight = weight.sum()
            if total_weight <= 0:
                # No point is assigned to this component: keep its previous
                # parameters (np.cov rejects all-zero aweights).
                continue
            self.means[i] = weight.dot(X) / total_weight
            cov = np.cov(X.T, aweights=weight / total_weight, bias=True)
            self.covariances[i] = np.atleast_2d(cov) + self.reg_covar * np.eye(d)

    def fit(self, X):
        """Run ``max_iter`` EM iterations on X (shape (n, d)) and return self.

        All state is reset first, so calling ``fit`` again starts from scratch
        instead of appending extra components.
        """
        X = np.asarray(X, dtype=float)
        self.n = len(X)
        d = X.shape[1]
        self.alpha = np.full(shape=self.k, fill_value=1 / self.k)
        self.weights = np.full(shape=(self.n, self.k), fill_value=1 / self.k)
        # Pick distinct rows when possible so no two components start identical.
        # Identical starting components would receive identical updates forever.
        rows = np.random.choice(self.n, self.k, replace=self.k > self.n)
        init_cov = np.atleast_2d(np.cov(X.T)) + self.reg_covar * np.eye(d)
        self.means = [X[r].copy() for r in rows]
        self.covariances = [init_cov.copy() for _ in range(self.k)]
        for _ in range(self.max_iter):
            self.e_step(X)
            self.m_step(X)
        return self

    def predict_proba(self, X):
        """Return the (n, k) posterior probability of each component for each row of X."""
        log_joint = self._log_joint(X)
        # Subtract the row maximum before exponentiating so small densities
        # cannot underflow into a 0/0 division.
        prob = np.exp(log_joint - log_joint.max(axis=1, keepdims=True))
        return prob / prob.sum(axis=1, keepdims=True)

    def get_aic_bic(self, X):
        """Return (AIC, BIC) of the fitted model evaluated on X.

        The log-likelihood is summed in log space; a product of n densities
        underflows to 0. The parameter count is that of a full-covariance GMM:
        k*d means + k*d*(d+1)/2 covariance entries + (k-1) free mixing weights.
        """
        X = np.asarray(X, dtype=float)
        n, d = X.shape
        log_likelihood = self._logsumexp_rows(self._log_joint(X)).sum()
        n_params = self.k * d + self.k * d * (d + 1) // 2 + (self.k - 1)
        aic = -2 * log_likelihood + 2 * n_params
        bic = -2 * log_likelihood + np.log(n) * n_params
        return aic, bic
# Fit 3-, 8- and 16-component mixtures to each class and overlay them on a
# fresh figure each time; the calls run in the same order as three separate cells.
for n_components in (3, 8, 16):
    gmm_model1 = GMM(n_components)
    gmm_model1.fit(X1)
    gmm_model2 = GMM(n_components)
    gmm_model2.fit(X2)
    fig = plt.figure(figsize=(10, 6))
    # draw_gaussian_plot draws onto the module-level name ax0.
    ax0 = fig.add_subplot(111)
    draw_gaussian_plot(X1, gmm_model1.means, gmm_model1.covariances, 'red')
    draw_gaussian_plot(X2, gmm_model2.means, gmm_model2.covariances, 'blue')
    plt.show()
def plot_aic_bic(X, K):
    """Fit a GMM for every k in K, plot AIC and BIC against k, and report the best k.

    The selected k is the one that minimises AIC + BIC.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Data to model.
    K : iterable of int
        Candidate component counts. They need not start at 1 or be consecutive.

    Returns
    -------
    int
        The selected component count, which is also printed.
    """
    K = list(K)
    AIC = []
    BIC = []
    for k in K:
        gmm_model = GMM(k)
        gmm_model.fit(X)
        aic, bic = gmm_model.get_aic_bic(X)
        AIC.append(aic)
        BIC.append(bic)
    plt.plot(K, AIC, label='aic')
    plt.plot(K, BIC, label='bic')
    plt.title("AIC - BIC")
    plt.legend()
    plt.show()
    sum_ab = [a + b for a, b in zip(AIC, BIC)]
    # Map the position of the minimum back to its k. "index + 1" is only
    # correct when K is exactly 1..len(K).
    best_k = K[sum_ab.index(min(sum_ab))]
    print("best k : ", best_k)
    return best_k
# Candidate component counts 1..16, swept separately for each class.
K = range(1, 16 + 1)
for class_points in (X1, X2):
    plot_aic_bic(class_points, K)
best k : 6
best k : 9